From 40c4e5b3d2b8244c4d779b4eaefd354b8a8b5f9a Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Sat, 11 May 2024 19:00:34 +0800 Subject: [PATCH] breaking: remove multi-task support in tf (#3763) ## Summary by CodeRabbit - **New Features** - Removed the `multi_task` parameter across various descriptor initialization methods, streamlining the setup process. - Introduced a new option `--head` for specifying a model branch to freeze in multi-task mode. - **Bug Fixes** - Corrected initialization and training processes by removing outdated multi-task functionalities. - **Documentation** - Updated guides on model freezing and training to reflect the removal of multi-task functionalities and the shift towards using the PyTorch backend. - **Refactor** - Eliminated redundant code and simplified parameter assignments in training scripts. - **Chores** - Removed unused dictionaries and outdated code across several modules to clean up the codebase. --------- Signed-off-by: Duo <50307526+iProzd@users.noreply.github.com> Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> --- deepmd/dpmodel/descriptor/dpa1.py | 2 +- deepmd/dpmodel/descriptor/se_e2_a.py | 4 +- deepmd/dpmodel/descriptor/se_r.py | 4 +- deepmd/main.py | 6 - deepmd/tf/descriptor/hybrid.py | 4 +- deepmd/tf/descriptor/se_a.py | 35 +- deepmd/tf/descriptor/se_a_ebd_v2.py | 2 - deepmd/tf/descriptor/se_atten.py | 28 +- deepmd/tf/descriptor/se_atten_v2.py | 4 - deepmd/tf/descriptor/se_r.py | 13 +- deepmd/tf/descriptor/se_t.py | 33 +- deepmd/tf/entrypoints/freeze.py | 216 +----- deepmd/tf/entrypoints/train.py | 67 +- deepmd/tf/env.py | 26 - deepmd/tf/model/__init__.py | 4 - deepmd/tf/model/multi.py | 677 ------------------ deepmd/tf/train/trainer.py | 553 +++----------- deepmd/tf/utils/multi_init.py | 170 ----- deepmd/utils/argcheck.py | 256 ------- doc/freeze/freeze.md | 15 +- doc/train/multi-task-training-tf.md | 161 +---- .../water_multi_task/ener_dipole/input.json | 135 ---- source/tests/common/test_examples.py | 1 - source/tests/tf/test_init_frz_model_multi.py | 254 ------- source/tests/tf/test_layer_name.py | 150 ---- source/tests/tf/test_model_multi.py | 264 ------- source/tests/tf/test_nvnmd_entrypoints.py | 6 + source/tests/tf/water_layer_name.json | 105 --- source/tests/tf/water_multi.json | 103 --- 29 files changed, 183 insertions(+), 3115 deletions(-) delete mode 100644 deepmd/tf/model/multi.py delete mode 100644 deepmd/tf/utils/multi_init.py delete mode 100644 examples/water_multi_task/ener_dipole/input.json delete mode 100644 source/tests/tf/test_init_frz_model_multi.py delete mode 100644 source/tests/tf/test_layer_name.py delete mode 100644 source/tests/tf/test_model_multi.py delete mode 100644 source/tests/tf/water_layer_name.json delete mode 100644 source/tests/tf/water_multi.json diff --git a/deepmd/dpmodel/descriptor/dpa1.py b/deepmd/dpmodel/descriptor/dpa1.py index 14f4851023..a239412416 100644 --- a/deepmd/dpmodel/descriptor/dpa1.py +++ b/deepmd/dpmodel/descriptor/dpa1.py @@ -245,7 +245,7 @@ def __init__( # consistent with argcheck, not used though seed: Optional[int] = None, ) -> None: - ## seed, uniform_seed, multi_task, not included. + ## seed, uniform_seed, not included. # Ensure compatibility with the deprecated stripped_type_embedding option. if stripped_type_embedding is not None: # Use the user-set stripped_type_embedding parameter first diff --git a/deepmd/dpmodel/descriptor/se_e2_a.py b/deepmd/dpmodel/descriptor/se_e2_a.py index adc1913e96..193383ac4f 100644 --- a/deepmd/dpmodel/descriptor/se_e2_a.py +++ b/deepmd/dpmodel/descriptor/se_e2_a.py @@ -119,8 +119,6 @@ class DescrptSeA(NativeOP, BaseDescriptor): The activation function in the embedding net. Supported options are |ACTIVATION_FN| precision The precision of the embedding net parameters. Supported options are |PRECISION| - multi_task - If the model has multi fitting nets to train. spin The deepspin object. @@ -159,7 +157,7 @@ def __init__( # consistent with argcheck, not used though seed: Optional[int] = None, ) -> None: - ## seed, uniform_seed, multi_task, not included. + ## seed, uniform_seed, not included. if spin is not None: raise NotImplementedError("spin is not implemented") diff --git a/deepmd/dpmodel/descriptor/se_r.py b/deepmd/dpmodel/descriptor/se_r.py index ad802d5b25..5175b91ae1 100644 --- a/deepmd/dpmodel/descriptor/se_r.py +++ b/deepmd/dpmodel/descriptor/se_r.py @@ -75,8 +75,6 @@ class DescrptSeR(NativeOP, BaseDescriptor): The activation function in the embedding net. Supported options are |ACTIVATION_FN| precision The precision of the embedding net parameters. Supported options are |PRECISION| - multi_task - If the model has multi fitting nets to train. spin The deepspin object. @@ -114,7 +112,7 @@ def __init__( # consistent with argcheck, not used though seed: Optional[int] = None, ) -> None: - ## seed, uniform_seed, multi_task, not included. + ## seed, uniform_seed, not included. if not type_one_side: raise NotImplementedError("type_one_side == False not implemented") if spin is not None: diff --git a/deepmd/main.py b/deepmd/main.py index e9e7b0fcad..eab23ddb3f 100644 --- a/deepmd/main.py +++ b/deepmd/main.py @@ -323,12 +323,6 @@ def main_parser() -> argparse.ArgumentParser: default=None, help="(Supported backend: TensorFlow) the name of weight file (.npy), if set, save the model's weight into the file", ) - parser_frz.add_argument( - "--united-model", - action="store_true", - default=False, - help="(Supported backend: TensorFlow) When in multi-task mode, freeze all nodes into one united model", - ) parser_frz.add_argument( "--head", default=None, diff --git a/deepmd/tf/descriptor/hybrid.py b/deepmd/tf/descriptor/hybrid.py index 4e7eaa2c92..7c69ea5202 100644 --- a/deepmd/tf/descriptor/hybrid.py +++ b/deepmd/tf/descriptor/hybrid.py @@ -46,7 +46,6 @@ class DescrptHybrid(Descriptor): def __init__( self, list: List[Union[Descriptor, Dict[str, Any]]], - multi_task: bool = False, ntypes: Optional[int] = None, spin: Optional[Spin] = None, **kwargs, @@ -59,13 +58,12 @@ def __init__( "cannot build descriptor from an empty list of descriptors." ) formatted_descript_list = [] - self.multi_task = multi_task for ii in descrpt_list: if isinstance(ii, Descriptor): formatted_descript_list.append(ii) elif isinstance(ii, dict): formatted_descript_list.append( - Descriptor(**ii, ntypes=ntypes, spin=spin, multi_task=multi_task) + Descriptor(**ii, ntypes=ntypes, spin=spin) ) else: raise NotImplementedError diff --git a/deepmd/tf/descriptor/se_a.py b/deepmd/tf/descriptor/se_a.py index 6684e21522..51c79f36af 100644 --- a/deepmd/tf/descriptor/se_a.py +++ b/deepmd/tf/descriptor/se_a.py @@ -152,8 +152,6 @@ class DescrptSeA(DescrptSe): The precision of the embedding net parameters. Supported options are |PRECISION| uniform_seed Only for the purpose of backward compatibility, retrieves the old behavior of using the random seed - multi_task - If the model has multi fitting nets to train. env_protection: float Protection parameter to prevent division by zero errors during environment matrix calculations. @@ -181,7 +179,6 @@ def __init__( activation_function: str = "tanh", precision: str = "default", uniform_seed: bool = False, - multi_task: bool = False, spin: Optional[Spin] = None, tebd_input_mode: str = "concat", env_protection: float = 0.0, # not implement!! @@ -304,15 +301,6 @@ def __init__( self.stat_descrpt *= tf.reshape(mask, tf.shape(self.stat_descrpt)) self.sub_sess = tf.Session(graph=sub_graph, config=default_tf_session_config) self.original_sel = None - self.multi_task = multi_task - if multi_task: - self.stat_dict = { - "sumr": [], - "suma": [], - "sumn": [], - "sumr2": [], - "suma2": [], - } def get_rcut(self) -> float: """Returns the cut-off radius.""" @@ -392,21 +380,14 @@ def compute_input_stats( sumn.append(sysn) sumr2.append(sysr2) suma2.append(sysa2) - if not self.multi_task: - stat_dict = { - "sumr": sumr, - "suma": suma, - "sumn": sumn, - "sumr2": sumr2, - "suma2": suma2, - } - self.merge_input_stats(stat_dict) - else: - self.stat_dict["sumr"] += sumr - self.stat_dict["suma"] += suma - self.stat_dict["sumn"] += sumn - self.stat_dict["sumr2"] += sumr2 - self.stat_dict["suma2"] += suma2 + stat_dict = { + "sumr": sumr, + "suma": suma, + "sumn": sumn, + "sumr2": sumr2, + "suma2": suma2, + } + self.merge_input_stats(stat_dict) def merge_input_stats(self, stat_dict): """Merge the statisitcs computed from compute_input_stats to obtain the self.davg and self.dstd. diff --git a/deepmd/tf/descriptor/se_a_ebd_v2.py b/deepmd/tf/descriptor/se_a_ebd_v2.py index 9b92931b7f..9afa6598d1 100644 --- a/deepmd/tf/descriptor/se_a_ebd_v2.py +++ b/deepmd/tf/descriptor/se_a_ebd_v2.py @@ -43,7 +43,6 @@ def __init__( activation_function: str = "tanh", precision: str = "default", uniform_seed: bool = False, - multi_task: bool = False, spin: Optional[Spin] = None, **kwargs, ) -> None: @@ -63,7 +62,6 @@ def __init__( activation_function=activation_function, precision=precision, uniform_seed=uniform_seed, - multi_task=multi_task, spin=spin, tebd_input_mode="strip", **kwargs, diff --git a/deepmd/tf/descriptor/se_atten.py b/deepmd/tf/descriptor/se_atten.py index dcf3f3c24a..8cbc0ab689 100644 --- a/deepmd/tf/descriptor/se_atten.py +++ b/deepmd/tf/descriptor/se_atten.py @@ -142,8 +142,6 @@ class DescrptSeAtten(DescrptSeA): Whether to mask the diagonal in the attention weights. ln_eps: float, Optional The epsilon value for layer normalization. - multi_task: bool - If the model has multi fitting nets to train. tebd_input_mode: str The input mode of the type embedding. Supported modes are ["concat", "strip"]. - "concat": Concatenate the type embedding with the smoothed radial information as the union input for the embedding network. @@ -188,7 +186,6 @@ def __init__( attn_layer: int = 2, attn_dotr: bool = True, attn_mask: bool = False, - multi_task: bool = False, smooth_type_embedding: bool = False, tebd_input_mode: str = "concat", # not implemented @@ -246,7 +243,6 @@ def __init__( activation_function=activation_function, precision=precision, uniform_seed=uniform_seed, - multi_task=multi_task, ) """ Constructor @@ -403,21 +399,14 @@ def compute_input_stats( sumn.append(sysn) sumr2.append(sysr2) suma2.append(sysa2) - if not self.multi_task: - stat_dict = { - "sumr": sumr, - "suma": suma, - "sumn": sumn, - "sumr2": sumr2, - "suma2": suma2, - } - self.merge_input_stats(stat_dict) - else: - self.stat_dict["sumr"] += sumr - self.stat_dict["suma"] += suma - self.stat_dict["sumn"] += sumn - self.stat_dict["sumr2"] += sumr2 - self.stat_dict["suma2"] += suma2 + stat_dict = { + "sumr": sumr, + "suma": suma, + "sumn": sumn, + "sumr2": sumr2, + "suma2": suma2, + } + self.merge_input_stats(stat_dict) def enable_compression( self, @@ -2117,7 +2106,6 @@ def __init__( attn_layer=attn_layer, attn_dotr=attn_dotr, attn_mask=attn_mask, - multi_task=True, trainable_ln=trainable_ln, ln_eps=ln_eps, smooth_type_embedding=smooth_type_embedding, diff --git a/deepmd/tf/descriptor/se_atten_v2.py b/deepmd/tf/descriptor/se_atten_v2.py index 61e672788e..6204f27855 100644 --- a/deepmd/tf/descriptor/se_atten_v2.py +++ b/deepmd/tf/descriptor/se_atten_v2.py @@ -59,8 +59,6 @@ class DescrptSeAttenV2(DescrptSeAtten): Whether to dot the relative coordinates on the attention weights as a gated scheme. attn_mask Whether to mask the diagonal in the attention weights. - multi_task - If the model has multi fitting nets to train. """ def __init__( @@ -84,7 +82,6 @@ def __init__( attn_layer: int = 2, attn_dotr: bool = True, attn_mask: bool = False, - multi_task: bool = False, **kwargs, ) -> None: DescrptSeAtten.__init__( @@ -108,7 +105,6 @@ def __init__( attn_layer=attn_layer, attn_dotr=attn_dotr, attn_mask=attn_mask, - multi_task=multi_task, tebd_input_mode="strip", smooth_type_embedding=True, **kwargs, diff --git a/deepmd/tf/descriptor/se_r.py b/deepmd/tf/descriptor/se_r.py index 64a599716f..c34734a8cf 100644 --- a/deepmd/tf/descriptor/se_r.py +++ b/deepmd/tf/descriptor/se_r.py @@ -102,7 +102,6 @@ def __init__( activation_function: str = "tanh", precision: str = "default", uniform_seed: bool = False, - multi_task: bool = False, spin: Optional[Spin] = None, env_protection: float = 0.0, # not implement!! **kwargs, @@ -211,9 +210,6 @@ def __init__( self.sub_sess = tf.Session( graph=sub_graph, config=default_tf_session_config ) - self.multi_task = multi_task - if multi_task: - self.stat_dict = {"sumr": [], "sumn": [], "sumr2": []} def get_rcut(self): """Returns the cut-off radius.""" @@ -282,13 +278,8 @@ def compute_input_stats( sumr.append(sysr) sumn.append(sysn) sumr2.append(sysr2) - if not self.multi_task: - stat_dict = {"sumr": sumr, "sumn": sumn, "sumr2": sumr2} - self.merge_input_stats(stat_dict) - else: - self.stat_dict["sumr"] += sumr - self.stat_dict["sumn"] += sumn - self.stat_dict["sumr2"] += sumr2 + stat_dict = {"sumr": sumr, "sumn": sumn, "sumr2": sumr2} + self.merge_input_stats(stat_dict) def merge_input_stats(self, stat_dict): """Merge the statisitcs computed from compute_input_stats to obtain the self.davg and self.dstd. diff --git a/deepmd/tf/descriptor/se_t.py b/deepmd/tf/descriptor/se_t.py index 77234cb92e..cd0a9c0a19 100644 --- a/deepmd/tf/descriptor/se_t.py +++ b/deepmd/tf/descriptor/se_t.py @@ -90,7 +90,6 @@ def __init__( activation_function: str = "tanh", precision: str = "default", uniform_seed: bool = False, - multi_task: bool = False, **kwargs, ) -> None: """Constructor.""" @@ -172,15 +171,6 @@ def __init__( sel_r=self.sel_r, ) self.sub_sess = tf.Session(graph=sub_graph, config=default_tf_session_config) - self.multi_task = multi_task - if multi_task: - self.stat_dict = { - "sumr": [], - "suma": [], - "sumn": [], - "sumr2": [], - "suma2": [], - } def get_rcut(self) -> float: """Returns the cut-off radius.""" @@ -256,21 +246,14 @@ def compute_input_stats( sumn.append(sysn) sumr2.append(sysr2) suma2.append(sysa2) - if not self.multi_task: - stat_dict = { - "sumr": sumr, - "suma": suma, - "sumn": sumn, - "sumr2": sumr2, - "suma2": suma2, - } - self.merge_input_stats(stat_dict) - else: - self.stat_dict["sumr"] += sumr - self.stat_dict["suma"] += suma - self.stat_dict["sumn"] += sumn - self.stat_dict["sumr2"] += sumr2 - self.stat_dict["suma2"] += suma2 + stat_dict = { + "sumr": sumr, + "suma": suma, + "sumn": sumn, + "sumr2": sumr2, + "suma2": suma2, + } + self.merge_input_stats(stat_dict) def merge_input_stats(self, stat_dict): """Merge the statisitcs computed from compute_input_stats to obtain the self.davg and self.dstd. diff --git a/deepmd/tf/entrypoints/freeze.py b/deepmd/tf/entrypoints/freeze.py index a2b07f2c5e..787d26e9a4 100755 --- a/deepmd/tf/entrypoints/freeze.py +++ b/deepmd/tf/entrypoints/freeze.py @@ -7,7 +7,6 @@ https://blog.metaflow.fr/tensorflow-how-to-freeze-a-model-and-serve-it-with-a-python-api-d4f3596b3adc """ -import json import logging from os.path import ( abspath, @@ -27,7 +26,6 @@ import deepmd.tf.op # noqa: F401 from deepmd.tf.env import ( FITTING_NET_PATTERN, - REMOVE_SUFFIX_DICT, tf, ) from deepmd.tf.nvnmd.entrypoints.freeze import ( @@ -77,103 +75,6 @@ def _transfer_fitting_net_trainable_variables(sess, old_graph_def, raw_graph_def return old_graph_def -def _remove_fitting_net_suffix(output_graph_def, out_suffix): - """Remove fitting net suffix for multi-task mode. - - Parameters - ---------- - output_graph_def : tf.GraphDef - The output graph to remove suffix. - out_suffix : str - The suffix to remove. - """ - - def change_name(name, suffix): - if suffix in name: - for item in REMOVE_SUFFIX_DICT: - if item.format(suffix) in name: - name = name.replace(item.format(suffix), REMOVE_SUFFIX_DICT[item]) - break - assert suffix not in name, "fitting net name illegal!" - return name - - for node in output_graph_def.node: - if out_suffix in node.name: - node.name = change_name(node.name, out_suffix) - for idx in range(len(node.input)): - if out_suffix in node.input[idx]: - node.input[idx] = change_name(node.input[idx], out_suffix) - attr_list = node.attr["_class"].list.s - for idx in range(len(attr_list)): - if out_suffix in bytes.decode(attr_list[idx]): - attr_list[idx] = bytes( - change_name(bytes.decode(attr_list[idx]), out_suffix), - encoding="utf8", - ) - return output_graph_def - - -def _modify_model_suffix(output_graph_def, out_suffix, freeze_type): - """Modify model suffix in graph nodes for multi-task mode, including fitting net, model attr and training script. - - Parameters - ---------- - output_graph_def : tf.GraphDef - The output graph to remove suffix. - out_suffix : str - The suffix to remove. - freeze_type : str - The model type to freeze. - """ - output_graph_def = _remove_fitting_net_suffix(output_graph_def, out_suffix) - for node in output_graph_def.node: - if "model_attr/model_type" in node.name: - node.attr["value"].tensor.string_val[0] = bytes( - freeze_type, encoding="utf8" - ) - # change the input script for frozen model - elif "train_attr/training_script" in node.name: - jdata = json.loads(node.attr["value"].tensor.string_val[0]) - # fitting net - assert out_suffix in jdata["model"]["fitting_net_dict"] - jdata["model"]["fitting_net"] = jdata["model"].pop("fitting_net_dict")[ - out_suffix - ] - # data systems - systems = jdata["training"].pop("data_dict") - if out_suffix in systems: - jdata["training"]["training_data"] = systems[out_suffix][ - "training_data" - ] - if "validation_data" in systems[out_suffix]: - jdata["training"]["validation_data"] = systems[out_suffix][ - "validation_data" - ] - else: - jdata["training"]["training_data"] = {} - log.warning( - f"The fitting net {out_suffix} has no training data in input script, resulting in " - "untrained frozen model, and cannot be compressed directly! " - ) - # loss - if "loss_dict" in jdata: - loss_dict = jdata.pop("loss_dict") - if out_suffix in loss_dict: - jdata["loss"] = loss_dict[out_suffix] - # learning_rate - if "learning_rate_dict" in jdata: - learning_rate_dict = jdata.pop("learning_rate_dict") - if out_suffix in learning_rate_dict: - jdata["learning_rate"] = learning_rate_dict[out_suffix] - # fitting weight - if "fitting_weight" in jdata["training"]: - jdata["training"].pop("fitting_weight") - node.attr["value"].tensor.string_val[0] = bytes( - json.dumps(jdata), encoding="utf8" - ) - return output_graph_def - - def _make_node_names( model_type: str, modifier_type: Optional[str] = None, @@ -272,10 +173,6 @@ def _make_node_names( "model_attr/sel_type", "model_attr/output_dim", ] - elif model_type == "multi_task": - assert ( - node_names is not None - ), "node_names must be defined in multi-task united model! " else: raise RuntimeError(f"unknown model type {model_type}") if modifier_type == "dipole_charge": @@ -381,11 +278,6 @@ def freeze_graph( input_graph, # The graph_def is used to retrieve the nodes output_node, # The output node names are used to select the usefull nodes ) - # if multi-task, change fitting_net suffix and model_type - if out_suffix != "": - output_graph_def = _modify_model_suffix( - output_graph_def, out_suffix, freeze_type - ) # If we need to transfer the fitting net variables output_graph_def = _transfer_fitting_net_trainable_variables( @@ -398,89 +290,12 @@ def freeze_graph( log.info(f"{len(output_graph_def.node):d} ops in the final graph.") -def freeze_graph_multi( - sess, - input_graph, - input_node, - modifier, - out_graph_name, - node_names, - united_model: bool = False, -): - """Freeze multiple graphs for multi-task model. - - Parameters - ---------- - sess : tf.Session - The default session. - input_graph : tf.GraphDef - The input graph_def stored from the checkpoint. - input_node : List[str] - The expected nodes to freeze. - modifier : Optional[str], optional - Modifier type if any, by default None. - out_graph_name : str - The output graph. - node_names : Optional[str], optional - Names of nodes to output, by default None. - united_model : bool - If freeze all nodes into one unit model - """ - input_script = json.loads( - run_sess(sess, "train_attr/training_script:0", feed_dict={}) - ) - assert ( - "model" in input_script.keys() and "fitting_net_dict" in input_script["model"] - ) - if not united_model: - for fitting_key in input_script["model"]["fitting_net_dict"]: - fitting_type = input_script["model"]["fitting_net_dict"][fitting_key][ - "type" - ] - if out_graph_name[-3:] == ".pb": - output_graph_item = out_graph_name[:-3] + f"_{fitting_key}.pb" - else: - output_graph_item = out_graph_name + f"_{fitting_key}" - freeze_graph( - sess, - input_graph, - input_node, - fitting_type, - modifier, - output_graph_item, - node_names, - out_suffix=fitting_key, - ) - else: - node_multi = [] - for fitting_key in input_script["model"]["fitting_net_dict"]: - fitting_type = input_script["model"]["fitting_net_dict"][fitting_key][ - "type" - ] - node_multi += _make_node_names( - fitting_type, modifier, out_suffix=fitting_key - ) - node_multi = list(set(node_multi)) - if node_names is not None: - node_multi = node_names - freeze_graph( - sess, - input_graph, - input_node, - "multi_task", - modifier, - out_graph_name, - node_multi, - ) - - def freeze( *, checkpoint_folder: str, output: str, node_names: Optional[str] = None, nvnmd_weight: Optional[str] = None, - united_model: bool = False, **kwargs, ): """Freeze the graph in supplied folder. @@ -495,8 +310,6 @@ def freeze( names of nodes to output, by default None nvnmd_weight : Optional[str], optional nvnmd weight file - united_model : bool - when in multi-task mode, freeze all nodes into one unit model **kwargs other arguments """ @@ -560,23 +373,12 @@ def freeze( modifier_type = None if nvnmd_weight is not None: save_weight(sess, nvnmd_weight) # nvnmd - if model_type != "multi_task": - freeze_graph( - sess, - input_graph_def, - nodes, - model_type, - modifier_type, - output_graph, - node_names, - ) - else: - freeze_graph_multi( - sess, - input_graph_def, - nodes, - modifier_type, - output_graph, - node_names, - united_model=united_model, - ) + freeze_graph( + sess, + input_graph_def, + nodes, + model_type, + modifier_type, + output_graph, + node_names, + ) diff --git a/deepmd/tf/entrypoints/train.py b/deepmd/tf/entrypoints/train.py index e573423fc3..2fef038f7d 100755 --- a/deepmd/tf/entrypoints/train.py +++ b/deepmd/tf/entrypoints/train.py @@ -43,9 +43,6 @@ from deepmd.tf.utils.finetune import ( replace_model_params_with_pretrained_model, ) -from deepmd.tf.utils.multi_init import ( - replace_model_params_with_frz_multi_model, -) from deepmd.utils.data_system import ( get_data, ) @@ -126,9 +123,6 @@ def train( jdata, run_opt.finetune ) - if "fitting_net_dict" in jdata["model"] and run_opt.init_frz_model is not None: - jdata = replace_model_params_with_frz_multi_model(jdata, run_opt.init_frz_model) - jdata = update_deepmd_input(jdata, warning=True, dump="input_v2_compat.json") jdata = normalize(jdata) @@ -193,62 +187,23 @@ def _do_work(jdata: Dict[str, Any], run_opt: RunOptions, is_compress: bool = Fal # setup data modifier modifier = get_modifier(jdata["model"].get("modifier", None)) - # check the multi-task mode - multi_task_mode = "fitting_net_dict" in jdata["model"] - # decouple the training data from the model compress process train_data = None valid_data = None if not is_compress: # init data - if not multi_task_mode: - train_data = get_data( - jdata["training"]["training_data"], rcut, ipt_type_map, modifier + train_data = get_data( + jdata["training"]["training_data"], rcut, ipt_type_map, modifier + ) + train_data.print_summary("training") + if jdata["training"].get("validation_data", None) is not None: + valid_data = get_data( + jdata["training"]["validation_data"], + rcut, + train_data.type_map, + modifier, ) - train_data.print_summary("training") - if jdata["training"].get("validation_data", None) is not None: - valid_data = get_data( - jdata["training"]["validation_data"], - rcut, - train_data.type_map, - modifier, - ) - valid_data.print_summary("validation") - else: - train_data = {} - valid_data = {} - for data_systems in jdata["training"]["data_dict"]: - if ( - jdata["training"]["fitting_weight"][data_systems] > 0.0 - ): # check only the available pair - train_data[data_systems] = get_data( - jdata["training"]["data_dict"][data_systems]["training_data"], - rcut, - ipt_type_map, - modifier, - multi_task_mode, - ) - train_data[data_systems].print_summary( - f"training in {data_systems}" - ) - if ( - jdata["training"]["data_dict"][data_systems].get( - "validation_data", None - ) - is not None - ): - valid_data[data_systems] = get_data( - jdata["training"]["data_dict"][data_systems][ - "validation_data" - ], - rcut, - train_data[data_systems].type_map, - modifier, - multi_task_mode, - ) - valid_data[data_systems].print_summary( - f"validation in {data_systems}" - ) + valid_data.print_summary("validation") else: if modifier is not None: modifier.build_fv_graph() diff --git a/deepmd/tf/env.py b/deepmd/tf/env.py index 4c20a2b978..cdb4feadc3 100644 --- a/deepmd/tf/env.py +++ b/deepmd/tf/env.py @@ -111,7 +111,6 @@ def dlopen_library(module: str, filename: str): "EMBEDDING_NET_PATTERN", "TYPE_EMBEDDING_PATTERN", "ATTENTION_LAYER_PATTERN", - "REMOVE_SUFFIX_DICT", "TF_VERSION", "tf_py_version", ] @@ -209,31 +208,6 @@ def dlopen_library(module: str, filename: str): ) ) -REMOVE_SUFFIX_DICT = { - "model_attr/sel_type_{}": "model_attr/sel_type", - "model_attr/output_dim_{}": "model_attr/output_dim", - "_{}/": "/", - # when atom_ener is set - "_{}_1/": "_1/", - "o_energy_{}": "o_energy", - "o_force_{}": "o_force", - "o_virial_{}": "o_virial", - "o_atom_energy_{}": "o_atom_energy", - "o_atom_virial_{}": "o_atom_virial", - "o_dipole_{}": "o_dipole", - "o_global_dipole_{}": "o_global_dipole", - "o_polar_{}": "o_polar", - "o_global_polar_{}": "o_global_polar", - "o_rmat_{}": "o_rmat", - "o_rmat_deriv_{}": "o_rmat_deriv", - "o_nlist_{}": "o_nlist", - "o_rij_{}": "o_rij", - "o_dm_force_{}": "o_dm_force", - "o_dm_virial_{}": "o_dm_virial", - "o_dm_av_{}": "o_dm_av", - "o_wfc_{}": "o_wfc", -} - def set_mkl(): """Tuning MKL for the best performance. diff --git a/deepmd/tf/model/__init__.py b/deepmd/tf/model/__init__.py index 1d100f2b09..85cc74781d 100644 --- a/deepmd/tf/model/__init__.py +++ b/deepmd/tf/model/__init__.py @@ -18,9 +18,6 @@ from .ener import ( EnerModel, ) -from .multi import ( - MultiModel, -) from .tensor import ( DipoleModel, GlobalPolarModel, @@ -31,7 +28,6 @@ __all__ = [ "EnerModel", "DOSModel", - "MultiModel", "DipoleModel", "GlobalPolarModel", "PolarModel", diff --git a/deepmd/tf/model/multi.py b/deepmd/tf/model/multi.py deleted file mode 100644 index e49ad47ee3..0000000000 --- a/deepmd/tf/model/multi.py +++ /dev/null @@ -1,677 +0,0 @@ -# SPDX-License-Identifier: LGPL-3.0-or-later -import json -from typing import ( - Dict, - List, - Optional, -) - -import numpy as np - -from deepmd.tf.descriptor.descriptor import ( - Descriptor, -) -from deepmd.tf.env import ( - MODEL_VERSION, - global_cvt_2_ener_float, - op_module, - tf, -) -from deepmd.tf.fit import ( - DipoleFittingSeA, - DOSFitting, - EnerFitting, - GlobalPolarFittingSeA, - PolarFittingSeA, -) -from deepmd.tf.fit.fitting import ( - Fitting, -) -from deepmd.tf.loss.loss import ( - Loss, -) -from deepmd.tf.utils.argcheck import ( - type_embedding_args, -) -from deepmd.tf.utils.graph import ( - get_tensor_by_name_from_graph, -) -from deepmd.tf.utils.pair_tab import ( - PairTab, -) -from deepmd.tf.utils.spin import ( - Spin, -) -from deepmd.tf.utils.type_embed import ( - TypeEmbedNet, -) - -from .model import ( - Model, -) -from .model_stat import ( - make_stat_input, - merge_sys_stat, -) - - -@Model.register("multi") -class MultiModel(Model): - """Multi-task model. - - Parameters - ---------- - descriptor - Descriptor - fitting_net_dict - Dictionary of fitting nets - fitting_type_dict - deprecated argument - type_embedding - Type embedding net - type_map - Mapping atom type to the name (str) of the type. - For example `type_map[1]` gives the name of the type 1. - data_stat_nbatch - Number of frames used for data statistic - data_stat_protect - Protect parameter for atomic energy regression - use_srtab - The table for the short-range pairwise interaction added on top of DP. The table is a text data file with (N_t + 1) * N_t / 2 + 1 columes. The first colume is the distance between atoms. The second to the last columes are energies for pairs of certain types. For example we have two atom types, 0 and 1. The columes from 2nd to 4th are for 0-0, 0-1 and 1-1 correspondingly. - smin_alpha - The short-range tabulated interaction will be swithed according to the distance of the nearest neighbor. This distance is calculated by softmin. This parameter is the decaying parameter in the softmin. It is only required when `use_srtab` is provided. - sw_rmin - The lower boundary of the interpolation between short-range tabulated interaction and DP. It is only required when `use_srtab` is provided. - sw_rmin - The upper boundary of the interpolation between short-range tabulated interaction and DP. It is only required when `use_srtab` is provided. - """ - - model_type = "multi_task" - - def __init__( - self, - descriptor: dict, - fitting_net_dict: dict, - fitting_type_dict: Optional[dict] = None, # deprecated - type_embedding=None, - type_map: Optional[List[str]] = None, - data_stat_nbatch: int = 10, - data_stat_protect: float = 1e-2, - use_srtab: Optional[str] = None, # all the ener fitting will do this - smin_alpha: Optional[float] = None, - sw_rmin: Optional[float] = None, - sw_rmax: Optional[float] = None, - **kwargs, - ) -> None: - """Constructor.""" - super().__init__( - descriptor=descriptor, - fitting_net_dict=fitting_net_dict, - type_embedding=type_embedding, - type_map=type_map, - data_stat_nbatch=data_stat_nbatch, - data_stat_protect=data_stat_protect, - use_srtab=use_srtab, - smin_alpha=smin_alpha, - sw_rmin=sw_rmin, - sw_rmax=sw_rmax, - ) - if self.spin is not None and not isinstance(self.spin, Spin): - self.spin = Spin(**self.spin) - if isinstance(descriptor, Descriptor): - self.descrpt = descriptor - else: - self.descrpt = Descriptor( - **descriptor, - ntypes=len(self.get_type_map()), - multi_task=True, - spin=self.spin, - ) - - fitting_dict = {} - for item in fitting_net_dict: - item_fitting_param = fitting_net_dict[item] - if isinstance(item_fitting_param, Fitting): - fitting_dict[item] = item_fitting_param - else: - if item_fitting_param["type"] in ["dipole", "polar"]: - item_fitting_param["embedding_width"] = ( - self.descrpt.get_dim_rot_mat_1() - ) - fitting_dict[item] = Fitting( - **item_fitting_param, - descrpt=self.descrpt, - spin=self.spin, - ntypes=self.descrpt.get_ntypes(), - dim_descrpt=self.descrpt.get_dim_out(), - ) - - self.ntypes = self.descrpt.get_ntypes() - # type embedding - if type_embedding is not None and isinstance(type_embedding, TypeEmbedNet): - self.typeebd = type_embedding - elif type_embedding is not None: - self.typeebd = TypeEmbedNet( - ntypes=self.ntypes, - **type_embedding, - padding=self.descrpt.explicit_ntypes, - ) - elif self.descrpt.explicit_ntypes: - default_args = type_embedding_args() - default_args_dict = {i.name: i.default for i in default_args} - default_args_dict["activation_function"] = None - self.typeebd = TypeEmbedNet( - ntypes=self.ntypes, - **default_args_dict, - padding=True, - ) - else: - self.typeebd = None - - # descriptor - self.rcut = self.descrpt.get_rcut() - # fitting - self.fitting_dict = fitting_dict - self.numb_fparam_dict = { - item: self.fitting_dict[item].get_numb_fparam() - for item in self.fitting_dict - if isinstance(self.fitting_dict[item], EnerFitting) - } - # other inputs - if type_map is None: - self.type_map = [] - else: - self.type_map = type_map - self.data_stat_nbatch = data_stat_nbatch - self.data_stat_protect = data_stat_protect - self.srtab_name = use_srtab - if self.srtab_name is not None: - self.srtab = PairTab(self.srtab_name) - self.smin_alpha = smin_alpha - self.sw_rmin = sw_rmin - self.sw_rmax = sw_rmax - else: - self.srtab = None - - def get_rcut(self): - return self.rcut - - def get_ntypes(self): - return self.ntypes - - def get_type_map(self): - return self.type_map - - def data_stat(self, data): - for fitting_key in data: - all_stat = make_stat_input( - data[fitting_key], self.data_stat_nbatch, merge_sys=False - ) - m_all_stat = merge_sys_stat(all_stat) - self._compute_input_stat( - m_all_stat, - protection=self.data_stat_protect, - mixed_type=data[fitting_key].mixed_type, - fitting_key=fitting_key, - ) - self._compute_output_stat( - all_stat, - mixed_type=data[fitting_key].mixed_type, - fitting_key=fitting_key, - ) - self.descrpt.merge_input_stats(self.descrpt.stat_dict) - - def _compute_input_stat( - self, all_stat, protection=1e-2, mixed_type=False, fitting_key="" - ): - if mixed_type: - self.descrpt.compute_input_stats( - all_stat["coord"], - all_stat["box"], - all_stat["type"], - all_stat["natoms_vec"], - all_stat["default_mesh"], - all_stat, - mixed_type, - all_stat["real_natoms_vec"], - ) - else: - self.descrpt.compute_input_stats( - all_stat["coord"], - all_stat["box"], - all_stat["type"], - all_stat["natoms_vec"], - all_stat["default_mesh"], - all_stat, - ) - if hasattr(self.fitting_dict[fitting_key], "compute_input_stats"): - self.fitting_dict[fitting_key].compute_input_stats( - all_stat, protection=protection - ) - - def _compute_output_stat(self, all_stat, mixed_type=False, fitting_key=""): - if hasattr(self.fitting_dict[fitting_key], "compute_output_stats"): - if mixed_type: - self.fitting_dict[fitting_key].compute_output_stats( - all_stat, mixed_type=mixed_type - ) - else: - self.fitting_dict[fitting_key].compute_output_stats(all_stat) - - def build( - self, - coord_, - atype_, - natoms, - box, - mesh, - input_dict, - frz_model=None, - ckpt_meta: Optional[str] = None, - suffix="", - reuse=None, - ): - if input_dict is None: - input_dict = {} - with tf.variable_scope("model_attr" + suffix, reuse=reuse): - t_tmap = tf.constant(" ".join(self.type_map), name="tmap", dtype=tf.string) - t_mt = tf.constant(self.model_type, name="model_type", dtype=tf.string) - t_ver = tf.constant(MODEL_VERSION, name="model_version", dtype=tf.string) - t_st = {} - t_od = {} - sel_type = {} - natomsel = {} - nout = {} - for fitting_key in self.fitting_dict: - if isinstance( - self.fitting_dict[fitting_key], - (DipoleFittingSeA, PolarFittingSeA, GlobalPolarFittingSeA), - ): - sel_type[fitting_key] = self.fitting_dict[ - fitting_key - ].get_sel_type() - natomsel[fitting_key] = sum( - natoms[2 + type_i] for type_i in sel_type[fitting_key] - ) - nout[fitting_key] = self.fitting_dict[fitting_key].get_out_size() - t_st[fitting_key] = tf.constant( - sel_type[fitting_key], - name=f"sel_type_{fitting_key}", - dtype=tf.int32, - ) - t_od[fitting_key] = tf.constant( - nout[fitting_key], - name=f"output_dim_{fitting_key}", - dtype=tf.int32, - ) - - if self.srtab is not None: - tab_info, tab_data = self.srtab.get() - self.tab_info = tf.get_variable( - "t_tab_info", - tab_info.shape, - dtype=tf.float64, - trainable=False, - initializer=tf.constant_initializer(tab_info, dtype=tf.float64), - ) - self.tab_data = tf.get_variable( - "t_tab_data", - tab_data.shape, - dtype=tf.float64, - trainable=False, - initializer=tf.constant_initializer(tab_data, dtype=tf.float64), - ) - - coord = tf.reshape(coord_, [-1, natoms[1] * 3]) - atype = tf.reshape(atype_, [-1, natoms[1]]) - input_dict["nframes"] = tf.shape(coord)[0] - - # type embedding if any - if self.typeebd is not None: - type_embedding = self.build_type_embedding( - self.ntypes, - reuse=reuse, - suffix=suffix, - frz_model=frz_model, - ckpt_meta=ckpt_meta, - ) - input_dict["type_embedding"] = type_embedding - input_dict["atype"] = atype_ - - dout = self.build_descrpt( - coord, - atype, - natoms, - box, - mesh, - input_dict, - frz_model=frz_model, - ckpt_meta=ckpt_meta, - suffix=suffix, - reuse=reuse, - ) - dout = tf.identity(dout, name="o_descriptor") - - if self.srtab is not None: - nlist, rij, sel_a, sel_r = self.descrpt.get_nlist() - nnei_a = np.cumsum(sel_a)[-1] - nnei_r = np.cumsum(sel_r)[-1] - sw_lambda, sw_deriv = op_module.soft_min_switch( - atype, - rij, - nlist, - natoms, - sel_a=sel_a, - sel_r=sel_r, - alpha=self.smin_alpha, - rmin=self.sw_rmin, - rmax=self.sw_rmax, - ) - inv_sw_lambda = 1.0 - sw_lambda - # NOTICE: - # atom energy is not scaled, - # force and virial are scaled - tab_atom_ener, tab_force, tab_atom_virial = op_module.pair_tab( - self.tab_info, - self.tab_data, - atype, - rij, - nlist, - natoms, - sw_lambda, - sel_a=sel_a, - sel_r=sel_r, - ) - - rot_mat = self.descrpt.get_rot_mat() - rot_mat = tf.identity(rot_mat, name="o_rot_mat" + suffix) - self.atom_ener = {} - model_dict = {} - for fitting_key in self.fitting_dict: - if isinstance(self.fitting_dict[fitting_key], EnerFitting): - atom_ener = self.fitting_dict[fitting_key].build( - dout, - natoms, - input_dict, - reuse=reuse, - suffix=f"_{fitting_key}" + suffix, - ) - self.atom_ener[fitting_key] = atom_ener - if self.srtab is not None: - energy_diff = tab_atom_ener - tf.reshape(atom_ener, [-1, natoms[0]]) - tab_atom_ener = tf.reshape(sw_lambda, [-1]) * tf.reshape( - tab_atom_ener, [-1] - ) - atom_ener = tf.reshape(inv_sw_lambda, [-1]) * atom_ener - energy_raw = tab_atom_ener + atom_ener - else: - energy_raw = atom_ener - energy_raw = tf.reshape( - energy_raw, - [-1, natoms[0]], - name=f"o_atom_energy_{fitting_key}" + suffix, - ) - energy = tf.reduce_sum( - global_cvt_2_ener_float(energy_raw), - axis=1, - name=f"o_energy_{fitting_key}" + suffix, - ) - force, virial, atom_virial = self.descrpt.prod_force_virial( - atom_ener, natoms - ) - - if self.srtab is not None: - sw_force = op_module.soft_min_force( - energy_diff, - sw_deriv, - nlist, - natoms, - n_a_sel=nnei_a, - n_r_sel=nnei_r, - ) - force = force + sw_force + tab_force - - force = tf.reshape( - force, - [-1, 3 * natoms[1]], - name=f"o_force_{fitting_key}" + suffix, - ) - - if self.srtab is not None: - sw_virial, sw_atom_virial = op_module.soft_min_virial( - energy_diff, - sw_deriv, - rij, - nlist, - natoms, - n_a_sel=nnei_a, - n_r_sel=nnei_r, - ) - atom_virial = atom_virial + sw_atom_virial + tab_atom_virial - virial = ( - virial - + sw_virial - + tf.reduce_sum( - tf.reshape(tab_atom_virial, [-1, natoms[1], 9]), axis=1 - ) - ) - - virial = tf.reshape( - virial, [-1, 9], name=f"o_virial_{fitting_key}" + suffix - ) - atom_virial = tf.reshape( - atom_virial, - [-1, 9 * natoms[1]], - name=f"o_atom_virial_{fitting_key}" + suffix, - ) - - model_dict[fitting_key] = {} - model_dict[fitting_key]["energy"] = energy - model_dict[fitting_key]["force"] = force - model_dict[fitting_key]["virial"] = virial - model_dict[fitting_key]["atom_ener"] = energy_raw - model_dict[fitting_key]["atom_virial"] = atom_virial - model_dict[fitting_key]["coord"] = coord - model_dict[fitting_key]["atype"] = atype - elif isinstance( - self.fitting_dict[fitting_key], - (DipoleFittingSeA, PolarFittingSeA, GlobalPolarFittingSeA), - ): - tensor_name = { - DipoleFittingSeA: "dipole", - PolarFittingSeA: "polar", - GlobalPolarFittingSeA: "global_polar", - }[type(self.fitting_dict[fitting_key])] - output = self.fitting_dict[fitting_key].build( - dout, - rot_mat, - natoms, - input_dict, - reuse=reuse, - suffix=f"_{fitting_key}" + suffix, - ) - framesize = ( - nout - if "global" in tensor_name - else natomsel[fitting_key] * nout[fitting_key] - ) - output = tf.reshape( - output, - [-1, framesize], - name=f"o_{tensor_name}_{fitting_key}" + suffix, - ) - - model_dict[fitting_key] = {} - model_dict[fitting_key][tensor_name] = output - - if "global" not in tensor_name: - gname = "global_" + tensor_name - atom_out = tf.reshape( - output, [-1, natomsel[fitting_key], nout[fitting_key]] - ) - global_out = tf.reduce_sum(atom_out, axis=1) - global_out = tf.reshape( - global_out, - [-1, nout[fitting_key]], - name=f"o_{gname}_{fitting_key}" + suffix, - ) - - out_cpnts = tf.split(atom_out, nout[fitting_key], axis=-1) - force_cpnts = [] - virial_cpnts = [] - atom_virial_cpnts = [] - - for out_i in out_cpnts: - ( - force_i, - virial_i, - atom_virial_i, - ) = self.descrpt.prod_force_virial(out_i, natoms) - force_cpnts.append(tf.reshape(force_i, [-1, 3 * natoms[1]])) - virial_cpnts.append(tf.reshape(virial_i, [-1, 9])) - atom_virial_cpnts.append( - tf.reshape(atom_virial_i, [-1, 9 * natoms[1]]) - ) - - # [nframe x nout x (natom x 3)] - force = tf.concat( - force_cpnts, - axis=1, - name=f"o_force_{fitting_key}" + suffix, - ) - # [nframe x nout x 9] - virial = tf.concat( - virial_cpnts, - axis=1, - name=f"o_virial_{fitting_key}" + suffix, - ) - # [nframe x nout x (natom x 9)] - atom_virial = tf.concat( - atom_virial_cpnts, - axis=1, - name=f"o_atom_virial_{fitting_key}" + suffix, - ) - - model_dict[fitting_key][gname] = global_out - model_dict[fitting_key]["force"] = force - model_dict[fitting_key]["virial"] = virial - model_dict[fitting_key]["atom_virial"] = atom_virial - return model_dict - - def init_variables( - self, - graph: tf.Graph, - graph_def: tf.GraphDef, - model_type: str = "original_model", - suffix: str = "", - ) -> None: - """Init the embedding net variables with the given frozen model. - - Parameters - ---------- - graph : tf.Graph - The input frozen model graph - graph_def : tf.GraphDef - The input frozen model graph_def - model_type : str - the type of the model - suffix : str - suffix to name scope - """ - # self.frz_model will control the self.model to import the descriptor from the given frozen model instead of building from scratch... - # initialize fitting net with the given compressed frozen model - assert ( - model_type == "original_model" - ), "Initialization in multi-task mode does not support compressed model!" - self.descrpt.init_variables(graph, graph_def, suffix=suffix) - old_jdata = json.loads( - get_tensor_by_name_from_graph(graph, "train_attr/training_script") - ) - old_fitting_keys = list(old_jdata["model"]["fitting_net_dict"].keys()) - newly_added_fittings = set(self.fitting_dict.keys()) - set(old_fitting_keys) - reused_fittings = set(self.fitting_dict.keys()) - newly_added_fittings - for fitting_key in reused_fittings: - self.fitting_dict[fitting_key].init_variables( - graph, graph_def, suffix=f"_{fitting_key}" + suffix - ) - tf.constant("original_model", name="model_type", dtype=tf.string) - if self.typeebd is not None: - self.typeebd.init_variables(graph, graph_def, suffix=suffix) - - def enable_mixed_precision(self, mixed_prec: dict): - """Enable mixed precision for the model. - - Parameters - ---------- - mixed_prec : dict - The mixed precision config - """ - self.descrpt.enable_mixed_precision(mixed_prec) - for fitting_key in self.fitting_dict: - self.fitting_dict[fitting_key].enable_mixed_precision(self.mixed_prec) - - def get_numb_fparam(self) -> dict: - """Get the number of frame parameters.""" - numb_fparam_dict = {} - for fitting_key in self.fitting_dict: - if isinstance(self.fitting_dict[fitting_key], (EnerFitting, DOSFitting)): - numb_fparam_dict[fitting_key] = self.fitting_dict[ - fitting_key - ].get_numb_fparam() - else: - numb_fparam_dict[fitting_key] = 0 - return numb_fparam_dict - - def get_numb_aparam(self) -> dict: - """Get the number of atomic parameters.""" - numb_aparam_dict = {} - for fitting_key in self.fitting_dict: - if isinstance(self.fitting_dict[fitting_key], (EnerFitting, DOSFitting)): - numb_aparam_dict[fitting_key] = self.fitting_dict[ - fitting_key - ].get_numb_aparam() - else: - numb_aparam_dict[fitting_key] = 0 - return numb_aparam_dict - - def get_numb_dos(self) -> dict: - """Get the number of gridpoints in energy space.""" - numb_dos_dict = {} - for fitting_key in self.fitting_dict: - if isinstance(self.fitting_dict[fitting_key], DOSFitting): - numb_dos_dict[fitting_key] = self.fitting_dict[ - fitting_key - ].get_numb_dos() - else: - numb_dos_dict[fitting_key] = 0 - return numb_dos_dict - - def get_fitting(self) -> dict: - """Get the fitting(s).""" - return self.fitting_dict.copy() - - def get_loss(self, loss: dict, lr: dict) -> Dict[str, Loss]: - loss_dict = {} - for fitting_key in self.fitting_dict: - loss_param = loss.get(fitting_key, {}) - loss_dict[fitting_key] = self.fitting_dict[fitting_key].get_loss( - loss_param, lr[fitting_key] - ) - return loss_dict - - @classmethod - def update_sel(cls, global_jdata: dict, local_jdata: dict): - """Update the selection and perform neighbor statistics. - - Parameters - ---------- - global_jdata : dict - The global data, containing the training section - local_jdata : dict - The local data refer to the current class - """ - local_jdata_cpy = local_jdata.copy() - local_jdata_cpy["descriptor"] = Descriptor.update_sel( - global_jdata, local_jdata["descriptor"] - ) - return local_jdata_cpy diff --git a/deepmd/tf/train/trainer.py b/deepmd/tf/train/trainer.py index dc9f81957a..855b2ee722 100644 --- a/deepmd/tf/train/trainer.py +++ b/deepmd/tf/train/trainer.py @@ -43,13 +43,9 @@ from deepmd.tf.fit.ener import ( EnerFitting, ) -from deepmd.tf.model import ( - MultiModel, -) from deepmd.tf.model.model import ( Model, ) -from deepmd.tf.utils import random as dp_random from deepmd.tf.utils.data_system import ( DeepmdDataSystem, ) @@ -94,8 +90,6 @@ def __init__(self, jdata, run_opt, is_compress=False): def _init_param(self, jdata): # model config model_param = j_must_have(jdata, "model") - if "fitting_key" in model_param: - model_param["type"] = "multi" # nvnmd self.nvnmd_param = jdata.get("nvnmd", {}) @@ -107,7 +101,6 @@ def _init_param(self, jdata): # init model self.model = Model(**model_param) - self.multi_task_mode = isinstance(self.model, MultiModel) self.fitting = self.model.get_fitting() def get_lr_and_coef(lr_param): @@ -128,38 +121,15 @@ def get_lr_and_coef(lr_param): return lr, scale_lr_coef # learning rate - if not self.multi_task_mode: - lr_param = j_must_have(jdata, "learning_rate") - self.lr, self.scale_lr_coef = get_lr_and_coef(lr_param) - else: - self.lr_dict = {} - self.scale_lr_coef_dict = {} - lr_param_dict = jdata.get("learning_rate_dict", {}) - for fitting_key in self.fitting: - lr_param = lr_param_dict.get(fitting_key, {}) - ( - self.lr_dict[fitting_key], - self.scale_lr_coef_dict[fitting_key], - ) = get_lr_and_coef(lr_param) + lr_param = j_must_have(jdata, "learning_rate") + self.lr, self.scale_lr_coef = get_lr_and_coef(lr_param) # loss # infer loss type by fitting_type - if not self.multi_task_mode: - loss_param = jdata.get("loss", {}) - self.loss = self.model.get_loss(loss_param, self.lr) - else: - loss_param = jdata.get("loss_dict", {}) - self.loss_dict = self.model.get_loss(loss_param, self.lr_dict) + loss_param = jdata.get("loss", {}) + self.loss = self.model.get_loss(loss_param, self.lr) # training tr_data = jdata["training"] - self.fitting_weight = tr_data.get("fitting_weight", None) - if self.multi_task_mode: - self.fitting_key_list = [] - self.fitting_prob = [] - for fitting_key in self.fitting: - self.fitting_key_list.append(fitting_key) - # multi-task mode must have self.fitting_weight - self.fitting_prob.append(self.fitting_weight[fitting_key]) self.disp_file = tr_data.get("disp_file", "lcurve.out") self.disp_freq = tr_data.get("disp_freq", 1000) self.save_freq = tr_data.get("save_freq", 1000) @@ -188,24 +158,12 @@ def get_lr_and_coef(lr_param): # self.sys_probs = tr_data['sys_probs'] # self.auto_prob_style = tr_data['auto_prob'] self.useBN = False - if not self.multi_task_mode: - self.numb_fparam = self.model.get_numb_fparam() + self.numb_fparam = self.model.get_numb_fparam() - if tr_data.get("validation_data", None) is not None: - self.valid_numb_batch = tr_data["validation_data"].get("numb_btch", 1) - else: - self.valid_numb_batch = 1 + if tr_data.get("validation_data", None) is not None: + self.valid_numb_batch = tr_data["validation_data"].get("numb_btch", 1) else: - self.numb_fparam_dict = self.model.get_numb_fparam() - self.valid_numb_batch_dict = {} - data_dict = tr_data.get("data_dict", None) - for systems in data_dict: - if data_dict[systems].get("validation_data", None) is not None: - self.valid_numb_batch_dict[systems] = data_dict[systems][ - "validation_data" - ].get("numb_btch", 1) - else: - self.valid_numb_batch_dict[systems] = 1 + self.valid_numb_batch = 1 # if init the graph with the frozen model self.frz_model = None @@ -216,45 +174,21 @@ def build(self, data=None, stop_batch=0, origin_type_map=None, suffix=""): self.ntypes = self.model.get_ntypes() self.stop_batch = stop_batch - if not self.multi_task_mode: - if not self.is_compress and data.mixed_type: - assert isinstance( - self.fitting, EnerFitting - ), "Data in mixed_type format must use ener fitting!" + if not self.is_compress and data.mixed_type: + assert isinstance( + self.fitting, EnerFitting + ), "Data in mixed_type format must use ener fitting!" - if self.numb_fparam > 0: - log.info("training with %d frame parameter(s)" % self.numb_fparam) - else: - log.info("training without frame parameter") + if self.numb_fparam > 0: + log.info("training with %d frame parameter(s)" % self.numb_fparam) else: - assert ( - not self.is_compress - ), "You should not reach here, multi-task input could not be compressed! " - self.valid_fitting_key = [] - for fitting_key in data: - self.valid_fitting_key.append(fitting_key) - if data[fitting_key].mixed_type: - assert isinstance( - self.fitting[fitting_key], EnerFitting - ), f"Data for fitting net {fitting_key} in mixed_type format must use ener fitting!" - if self.numb_fparam_dict[fitting_key] > 0: - log.info( - "fitting net %s training with %d frame parameter(s)" - % (fitting_key, self.numb_fparam_dict[fitting_key]) - ) - else: - log.info( - f"fitting net {fitting_key} training without frame parameter" - ) + log.info("training without frame parameter") if not self.is_compress: # Usually, the type number of the model should be equal to that of the data # However, nt_model > nt_data should be allowed, since users may only want to # train using a dataset that only have some of elements - if not self.multi_task_mode: - single_data = data - else: - single_data = data[next(iter(data.keys()))] + single_data = data if self.ntypes < single_data.get_ntypes(): raise ValueError( "The number of types of the training data is %d, but that of the " @@ -266,12 +200,7 @@ def build(self, data=None, stop_batch=0, origin_type_map=None, suffix=""): % (single_data.get_ntypes(), self.ntypes) ) self.type_map = single_data.get_type_map() - if not self.multi_task_mode: - self.batch_size = data.get_batch_size() - else: - self.batch_size = {} - for fitting_key in data: - self.batch_size[fitting_key] = data[fitting_key].get_batch_size() + self.batch_size = data.get_batch_size() if self.run_opt.init_mode not in ( "init_from_model", "restart", @@ -314,52 +243,23 @@ def build(self, data=None, stop_batch=0, origin_type_map=None, suffix=""): def _build_lr(self): self._extra_train_ops = [] self.global_step = tf.train.get_or_create_global_step() - if not self.multi_task_mode: - self.learning_rate = self.lr.build(self.global_step, self.stop_batch) - else: - self.learning_rate_dict = {} - - for fitting_key in self.fitting: - self.learning_rate_dict[fitting_key] = self.lr_dict[fitting_key].build( - self.global_step, self.stop_batch - ) - + self.learning_rate = self.lr.build(self.global_step, self.stop_batch) log.info("built lr") def _build_loss(self): if self.stop_batch == 0: # l2 is not used if stop_batch is zero return None, None - if not self.multi_task_mode: - l2_l, l2_more = self.loss.build( - self.learning_rate, - self.place_holders["natoms_vec"], - self.model_pred, - self.place_holders, - suffix="test", - ) - - if self.mixed_prec is not None: - l2_l = tf.cast(l2_l, get_precision(self.mixed_prec["output_prec"])) - else: - l2_l, l2_more = {}, {} - for fitting_key in self.fitting: - lr = self.learning_rate_dict[fitting_key] - model = self.model_pred[fitting_key] - loss_dict = self.loss_dict[fitting_key] - - l2_l[fitting_key], l2_more[fitting_key] = loss_dict.build( - lr, - self.place_holders["natoms_vec"], - model, - self.place_holders, - suffix=fitting_key, - ) + l2_l, l2_more = self.loss.build( + self.learning_rate, + self.place_holders["natoms_vec"], + self.model_pred, + self.place_holders, + suffix="test", + ) - if self.mixed_prec is not None: - l2_l[fitting_key] = tf.cast( - l2_l[fitting_key], get_precision(self.mixed_prec["output_prec"]) - ) + if self.mixed_prec is not None: + l2_l = tf.cast(l2_l, get_precision(self.mixed_prec["output_prec"])) return l2_l, l2_more @@ -372,10 +272,7 @@ def _build_network(self, data, suffix=""): ) self._get_place_holders(data_requirement) else: - if not self.multi_task_mode: - self._get_place_holders(data.get_data_dict()) - else: - self._get_place_holders(data[next(iter(data.keys()))].get_data_dict()) + self._get_place_holders(data.get_data_dict()) self.place_holders["type"] = tf.placeholder(tf.int32, [None], name="t_type") self.place_holders["natoms_vec"] = tf.placeholder( @@ -402,39 +299,18 @@ def _build_network(self, data, suffix=""): log.info("built network") - def _build_optimizer(self, fitting_key=None): + def _build_optimizer(self): if self.run_opt.is_distrib: - if fitting_key is None: - if self.scale_lr_coef > 1.0: - log.info("Scale learning rate by coef: %f", self.scale_lr_coef) - optimizer = tf.train.AdamOptimizer( - self.learning_rate * self.scale_lr_coef - ) - else: - optimizer = tf.train.AdamOptimizer(self.learning_rate) - optimizer = self.run_opt._HVD.DistributedOptimizer(optimizer) - else: - if self.scale_lr_coef_dict[fitting_key] > 1.0: - log.info( - "Scale learning rate by coef: %f", - self.scale_lr_coef_dict[fitting_key], - ) - optimizer = tf.train.AdamOptimizer( - self.learning_rate_dict[fitting_key] - * self.scale_lr_coef_dict[fitting_key] - ) - else: - optimizer = tf.train.AdamOptimizer( - learning_rate=self.learning_rate_dict[fitting_key] - ) - optimizer = self.run_opt._HVD.DistributedOptimizer(optimizer) - else: - if fitting_key is None: - optimizer = tf.train.AdamOptimizer(learning_rate=self.learning_rate) - else: + if self.scale_lr_coef > 1.0: + log.info("Scale learning rate by coef: %f", self.scale_lr_coef) optimizer = tf.train.AdamOptimizer( - learning_rate=self.learning_rate_dict[fitting_key] + self.learning_rate * self.scale_lr_coef ) + else: + optimizer = tf.train.AdamOptimizer(self.learning_rate) + optimizer = self.run_opt._HVD.DistributedOptimizer(optimizer) + else: + optimizer = tf.train.AdamOptimizer(learning_rate=self.learning_rate) if self.mixed_prec is not None: _TF_VERSION = Version(TF_VERSION) @@ -460,28 +336,15 @@ def _build_training(self): trainable_variables = tf.trainable_variables() - if not self.multi_task_mode: - optimizer = self._build_optimizer() - apply_op = optimizer.minimize( - loss=self.l2_l, - global_step=self.global_step, - var_list=trainable_variables, - name="train_step", - ) - train_ops = [apply_op, *self._extra_train_ops] - self.train_op = tf.group(*train_ops) - else: - self.train_op = {} - for fitting_key in self.fitting: - optimizer = self._build_optimizer(fitting_key=fitting_key) - apply_op = optimizer.minimize( - loss=self.l2_l[fitting_key], - global_step=self.global_step, - var_list=trainable_variables, - name=f"train_step_{fitting_key}", - ) - train_ops = [apply_op, *self._extra_train_ops] - self.train_op[fitting_key] = tf.group(*train_ops) + optimizer = self._build_optimizer() + apply_op = optimizer.minimize( + loss=self.l2_l, + global_step=self.global_step, + var_list=trainable_variables, + name="train_step", + ) + train_ops = [apply_op, *self._extra_train_ops] + self.train_op = tf.group(*train_ops) log.info("built training") def _init_session(self): @@ -555,30 +418,16 @@ def train(self, train_data=None, valid_data=None): cur_batch = run_sess(self.sess, self.global_step) is_first_step = True self.cur_batch = cur_batch - if not self.multi_task_mode: - log.info( - "start training at lr %.2e (== %.2e), decay_step %d, decay_rate %f, final lr will be %.2e" - % ( - run_sess(self.sess, self.learning_rate), - self.lr.value(cur_batch), - self.lr.decay_steps_, - self.lr.decay_rate_, - self.lr.value(stop_batch), - ) + log.info( + "start training at lr %.2e (== %.2e), decay_step %d, decay_rate %f, final lr will be %.2e" + % ( + run_sess(self.sess, self.learning_rate), + self.lr.value(cur_batch), + self.lr.decay_steps_, + self.lr.decay_rate_, + self.lr.value(stop_batch), ) - else: - for fitting_key in self.fitting: - log.info( - "%s: start training at lr %.2e (== %.2e), decay_step %d, decay_rate %f, final lr will be %.2e" - % ( - fitting_key, - run_sess(self.sess, self.learning_rate_dict[fitting_key]), - self.lr_dict[fitting_key].value(cur_batch), - self.lr_dict[fitting_key].decay_steps_, - self.lr_dict[fitting_key].decay_rate_, - self.lr_dict[fitting_key].value(stop_batch), - ) - ) + ) prf_options = None prf_run_metadata = None @@ -623,88 +472,33 @@ def train(self, train_data=None, valid_data=None): next_datasetloader = None # dataset loader op - if not self.multi_task_mode: - datasetloader = DatasetLoader(train_data) - data_op = datasetloader.build() - else: - datasetloader = {} - data_op = {} - for fitting_key in self.fitting: - datasetloader[fitting_key] = DatasetLoader(train_data[fitting_key]) - data_op[fitting_key] = datasetloader[fitting_key].build() + datasetloader = DatasetLoader(train_data) + data_op = datasetloader.build() while cur_batch < stop_batch: # first round validation: if is_first_step: - if not self.multi_task_mode: - train_batch = train_data.get_batch() - batch_train_op = self.train_op - else: - fitting_idx = dp_random.choice( - np.arange(len(self.fitting_key_list)), - p=np.array(self.fitting_prob), - ) - fitting_key = self.fitting_key_list[fitting_idx] - train_batch = train_data[fitting_key].get_batch() - batch_train_op = self.train_op[fitting_key] + train_batch = train_data.get_batch() + batch_train_op = self.train_op else: train_batch = next_datasetloader.get_data_dict(next_train_batch_list) batch_train_op = next_batch_train_op fitting_key = next_fitting_key # for next round - if not self.multi_task_mode: - next_datasetloader = datasetloader - next_batch_train_op = self.train_op - next_train_batch_op = data_op - else: - fitting_idx = dp_random.choice( - np.arange(len(self.fitting_key_list)), p=np.array(self.fitting_prob) - ) - next_fitting_key = self.fitting_key_list[fitting_idx] - next_datasetloader = datasetloader[next_fitting_key] - next_batch_train_op = self.train_op[fitting_key] - next_train_batch_op = data_op[fitting_key] + next_datasetloader = datasetloader + next_batch_train_op = self.train_op + next_train_batch_op = data_op if self.display_in_training and is_first_step: if self.run_opt.is_chief: - if not self.multi_task_mode: - valid_batches = ( - [ - valid_data.get_batch() - for ii in range(self.valid_numb_batch) - ] - if valid_data is not None - else None - ) - self.valid_on_the_fly( - fp, [train_batch], valid_batches, print_header=True - ) - else: - train_batches = {} - valid_batches = {} - # valid_numb_batch_dict - for fitting_key_ii in train_data: - # enumerate fitting key as fitting_key_ii - train_batches[fitting_key_ii] = [ - train_data[fitting_key_ii].get_batch() - ] - valid_batches[fitting_key_ii] = ( - [ - valid_data[fitting_key_ii].get_batch() - for ii in range( - self.valid_numb_batch_dict[fitting_key_ii] - ) - ] - if fitting_key_ii in valid_data - else None - ) - self.valid_on_the_fly( - fp, - train_batches, - valid_batches, - print_header=True, - fitting_key=fitting_key, - ) + valid_batches = ( + [valid_data.get_batch() for ii in range(self.valid_numb_batch)] + if valid_data is not None + else None + ) + self.valid_on_the_fly( + fp, [train_batch], valid_batches, print_header=True + ) is_first_step = False if self.timing_in_training: @@ -741,36 +535,12 @@ def train(self, train_data=None, valid_data=None): if self.timing_in_training: tic = time.time() if self.run_opt.is_chief: - if not self.multi_task_mode: - valid_batches = ( - [ - valid_data.get_batch() - for ii in range(self.valid_numb_batch) - ] - if valid_data is not None - else None - ) - self.valid_on_the_fly(fp, [train_batch], valid_batches) - else: - train_batches = {} - valid_batches = {} - for fitting_key_ii in train_data: - train_batches[fitting_key_ii] = [ - train_data[fitting_key_ii].get_batch() - ] - valid_batches[fitting_key_ii] = ( - [ - valid_data[fitting_key_ii].get_batch() - for ii in range( - self.valid_numb_batch_dict[fitting_key_ii] - ) - ] - if fitting_key_ii in valid_data - else None - ) - self.valid_on_the_fly( - fp, train_batches, valid_batches, fitting_key=fitting_key - ) + valid_batches = ( + [valid_data.get_batch() for ii in range(self.valid_numb_batch)] + if valid_data is not None + else None + ) + self.valid_on_the_fly(fp, [train_batch], valid_batches) if self.timing_in_training: toc = time.time() test_time = toc - tic @@ -873,69 +643,30 @@ def valid_on_the_fly( valid_results = self.get_evaluation_results(valid_batches) cur_batch = self.cur_batch - if not self.multi_task_mode: - current_lr = run_sess(self.sess, self.learning_rate) - else: - assert ( - fitting_key is not None - ), "Fitting key must be assigned in validation!" - current_lr = None - # current_lr can be used as the learning rate of descriptor in the future - current_lr_dict = {} - for fitting_key_ii in train_batches: - current_lr_dict[fitting_key_ii] = run_sess( - self.sess, self.learning_rate_dict[fitting_key_ii] - ) + current_lr = run_sess(self.sess, self.learning_rate) if print_header: - self.print_header(fp, train_results, valid_results, self.multi_task_mode) - if not self.multi_task_mode: - self.print_on_training( - fp, - train_results, - valid_results, - cur_batch, - current_lr, - self.multi_task_mode, - ) - else: - assert ( - fitting_key is not None - ), "Fitting key must be assigned when printing learning rate!" - self.print_on_training( - fp, - train_results, - valid_results, - cur_batch, - current_lr, - self.multi_task_mode, - current_lr_dict, - ) + self.print_header(fp, train_results, valid_results) + self.print_on_training( + fp, + train_results, + valid_results, + cur_batch, + current_lr, + ) @staticmethod - def print_header(fp, train_results, valid_results, multi_task_mode=False): + def print_header(fp, train_results, valid_results): print_str = "" print_str += "# %5s" % "step" - if not multi_task_mode: - if valid_results is not None: - prop_fmt = " %11s %11s" - for k in train_results.keys(): - print_str += prop_fmt % (k + "_val", k + "_trn") - else: - prop_fmt = " %11s" - for k in train_results.keys(): - print_str += prop_fmt % (k + "_trn") - print_str += " %8s\n" % "lr" + if valid_results is not None: + prop_fmt = " %11s %11s" + for k in train_results.keys(): + print_str += prop_fmt % (k + "_val", k + "_trn") else: - for fitting_key in train_results: - if valid_results[fitting_key] is not None: - prop_fmt = " %11s %11s" - for k in train_results[fitting_key].keys(): - print_str += prop_fmt % (k + "_val", k + "_trn") - else: - prop_fmt = " %11s" - for k in train_results[fitting_key].keys(): - print_str += prop_fmt % (k + "_trn") - print_str += " %8s\n" % (fitting_key + "_lr") + prop_fmt = " %11s" + for k in train_results.keys(): + print_str += prop_fmt % (k + "_trn") + print_str += " %8s\n" % "lr" print_str += "# If there is no available reference data, rmse_*_{val,trn} will print nan\n" fp.write(print_str) fp.flush() @@ -947,71 +678,36 @@ def print_on_training( valid_results, cur_batch, cur_lr, - multi_task_mode=False, - cur_lr_dict=None, ): print_str = "" print_str += "%7d" % cur_batch - if not multi_task_mode: - if valid_results is not None: - prop_fmt = " %11.2e %11.2e" - for k in valid_results.keys(): - # assert k in train_results.keys() - print_str += prop_fmt % (valid_results[k], train_results[k]) - else: - prop_fmt = " %11.2e" - for k in train_results.keys(): - print_str += prop_fmt % (train_results[k]) - print_str += f" {cur_lr:8.1e}\n" + if valid_results is not None: + prop_fmt = " %11.2e %11.2e" + for k in valid_results.keys(): + # assert k in train_results.keys() + print_str += prop_fmt % (valid_results[k], train_results[k]) + else: + prop_fmt = " %11.2e" + for k in train_results.keys(): + print_str += prop_fmt % (train_results[k]) + print_str += f" {cur_lr:8.1e}\n" + log.info( + format_training_message_per_task( + batch=cur_batch, + task_name="trn", + rmse=train_results, + learning_rate=cur_lr, + ) + ) + if valid_results is not None: log.info( format_training_message_per_task( batch=cur_batch, - task_name="trn", - rmse=train_results, - learning_rate=cur_lr, + task_name="val", + rmse=valid_results, + learning_rate=None, ) ) - if valid_results is not None: - log.info( - format_training_message_per_task( - batch=cur_batch, - task_name="val", - rmse=valid_results, - learning_rate=None, - ) - ) - else: - for fitting_key in train_results: - if valid_results[fitting_key] is not None: - prop_fmt = " %11.2e %11.2e" - for k in valid_results[fitting_key].keys(): - # assert k in train_results[fitting_key].keys() - print_str += prop_fmt % ( - valid_results[fitting_key][k], - train_results[fitting_key][k], - ) - else: - prop_fmt = " %11.2e" - for k in train_results[fitting_key].keys(): - print_str += prop_fmt % (train_results[fitting_key][k]) - print_str += f" {cur_lr_dict[fitting_key]:8.1e}\n" - log.info( - format_training_message_per_task( - batch=cur_batch, - task_name=f"{fitting_key}_trn", - rmse=train_results[fitting_key], - learning_rate=cur_lr_dict[fitting_key], - ) - ) - if valid_results is not None: - log.info( - format_training_message_per_task( - batch=cur_batch, - task_name=f"{fitting_key}_val", - rmse=valid_results[fitting_key], - learning_rate=None, - ) - ) fp.write(print_str) fp.flush() @@ -1041,20 +737,9 @@ def eval_single_list(single_batch_list, loss, sess, get_feed_dict_func, prefix=" return single_results def get_evaluation_results(self, batch_list): - if not self.multi_task_mode: - avg_results = self.eval_single_list( - batch_list, self.loss, self.sess, self.get_feed_dict - ) - else: - avg_results = {} - for fitting_key in batch_list: - avg_results[fitting_key] = self.eval_single_list( - batch_list[fitting_key], - self.loss_dict[fitting_key], - self.sess, - self.get_feed_dict, - prefix=f"{fitting_key}_", - ) + avg_results = self.eval_single_list( + batch_list, self.loss, self.sess, self.get_feed_dict + ) return avg_results def save_compressed(self): diff --git a/deepmd/tf/utils/multi_init.py b/deepmd/tf/utils/multi_init.py deleted file mode 100644 index aef4bf4af9..0000000000 --- a/deepmd/tf/utils/multi_init.py +++ /dev/null @@ -1,170 +0,0 @@ -# SPDX-License-Identifier: LGPL-3.0-or-later -import json -import logging -from typing import ( - Any, - Dict, -) - -from deepmd.tf.utils.errors import ( - GraphWithoutTensorError, -) -from deepmd.tf.utils.graph import ( - get_tensor_by_name, -) - -log = logging.getLogger(__name__) - - -def replace_model_params_with_frz_multi_model( - jdata: Dict[str, Any], pretrained_model: str -): - """Replace the model params in input script according to pretrained frozen multi-task united model. - - Parameters - ---------- - jdata : Dict[str, Any] - input script - pretrained_model : str - filename of the pretrained frozen multi-task united model - """ - # Get the input script from the pretrained model - try: - t_jdata = get_tensor_by_name(pretrained_model, "train_attr/training_script") - except GraphWithoutTensorError as e: - raise RuntimeError( - f"The input frozen pretrained model: {input} has no training script, " - "which is not supported to perform multi-task training. " - "Please use the model pretrained with v2.1.5 or higher version of DeePMD-kit." - ) from e - pretrained_jdata = json.loads(t_jdata) - - # Check the model type - assert "fitting_net_dict" in pretrained_jdata["model"], ( - "The multi-task init process only supports models trained in multi-task mode and frozen into united model!" - "Please use '--united-model' argument in 'dp freeze' command." - ) - - # Check the type map - pretrained_type_map = pretrained_jdata["model"]["type_map"] - cur_type_map = jdata["model"].get("type_map", []) - out_line_type = [] - for i in cur_type_map: - if i not in pretrained_type_map: - out_line_type.append(i) - assert not out_line_type, ( - f"{out_line_type!s} type(s) not contained in the pretrained model! " - "Please choose another suitable one." - ) - if cur_type_map != pretrained_type_map: - log.info( - f"Change the type_map from {cur_type_map!s} to {pretrained_type_map!s}." - ) - jdata["model"]["type_map"] = pretrained_type_map - - # Change model configurations - pretrained_fitting_keys = sorted( - pretrained_jdata["model"]["fitting_net_dict"].keys() - ) - cur_fitting_keys = sorted(jdata["model"]["fitting_net_dict"].keys()) - newly_added_fittings = set(cur_fitting_keys) - set(pretrained_fitting_keys) - reused_fittings = set(cur_fitting_keys) - newly_added_fittings - log.info("Change the model configurations according to the pretrained one...") - - for config_key in ["type_embedding", "descriptor", "fitting_net_dict"]: - if ( - config_key not in jdata["model"].keys() - and config_key in pretrained_jdata["model"].keys() - ): - log.info( - "Add the '{}' from pretrained model: {}.".format( - config_key, str(pretrained_jdata["model"][config_key]) - ) - ) - jdata["model"][config_key] = pretrained_jdata["model"][config_key] - elif ( - config_key == "type_embedding" - and config_key in jdata["model"].keys() - and config_key not in pretrained_jdata["model"].keys() - ): - # 'type_embedding' can be omitted using 'se_atten' descriptor, and the activation_function will be None. - cur_para = jdata["model"].pop(config_key) - if "trainable" in cur_para and not cur_para["trainable"]: - jdata["model"][config_key] = { - "trainable": False, - "activation_function": "None", - } - log.info("The type_embeddings from pretrained model will be frozen.") - elif config_key == "fitting_net_dict": - if reused_fittings: - log.info( - f"These fitting nets will use the configurations from pretrained frozen model : {reused_fittings}." - ) - for fitting_key in reused_fittings: - _change_sub_config( - jdata["model"][config_key], - pretrained_jdata["model"][config_key], - fitting_key, - ) - if newly_added_fittings: - log.info( - f"These fitting nets will be initialized from scratch : {newly_added_fittings}." - ) - elif ( - config_key in jdata["model"].keys() - and config_key in pretrained_jdata["model"].keys() - and jdata["model"][config_key] != pretrained_jdata["model"][config_key] - ): - _change_sub_config(jdata["model"], pretrained_jdata["model"], config_key) - - # Change other multi-task configurations - log.info("Change the training configurations according to the pretrained one...") - for config_key in ["loss_dict", "training/data_dict"]: - cur_jdata = jdata - target_jdata = pretrained_jdata - for sub_key in config_key.split("/"): - cur_jdata = cur_jdata[sub_key] - target_jdata = target_jdata[sub_key] - for fitting_key in reused_fittings: - if fitting_key not in cur_jdata: - target_para = target_jdata[fitting_key] - cur_jdata[fitting_key] = target_para - log.info( - f"Add '{config_key}/{fitting_key}' configurations from the pretrained frozen model." - ) - - # learning rate dict keep backward compatibility - config_key = "learning_rate_dict" - single_config_key = "learning_rate" - cur_jdata = jdata - target_jdata = pretrained_jdata - if (single_config_key not in cur_jdata) and (config_key in cur_jdata): - cur_jdata = cur_jdata[config_key] - if config_key in target_jdata: - target_jdata = target_jdata[config_key] - for fitting_key in reused_fittings: - if fitting_key not in cur_jdata: - target_para = target_jdata[fitting_key] - cur_jdata[fitting_key] = target_para - log.info( - f"Add '{config_key}/{fitting_key}' configurations from the pretrained frozen model." - ) - else: - for fitting_key in reused_fittings: - if fitting_key not in cur_jdata: - cur_jdata[fitting_key] = {} - log.info( - f"Add '{config_key}/{fitting_key}' configurations as default." - ) - - return jdata - - -def _change_sub_config(jdata: Dict[str, Any], src_jdata: Dict[str, Any], sub_key: str): - target_para = src_jdata[sub_key] - cur_para = jdata[sub_key] - # TODO: keep some params that are irrelevant to model structures (need to discuss) - if "trainable" in cur_para.keys(): - target_para["trainable"] = cur_para["trainable"] - log.info(f"Change the '{sub_key}' from {cur_para!s} to {target_para!s}.") - jdata[sub_key] = target_para diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index fb0e0855b8..b1f5cda0b2 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -1508,7 +1508,6 @@ def model_args(exclude_hybrid=False): "type", [ standard_model_args(), - multi_model_args(), frozen_model_args(), pairtab_model_args(), *hybrid_models, @@ -1544,29 +1543,6 @@ def standard_model_args() -> Argument: return ca -def multi_model_args() -> Argument: - doc_descrpt = "The descriptor of atomic environment. See model[standard]/descriptor for details." - doc_fitting_net_dict = "The dictionary of multiple fitting nets in multi-task mode. Each fitting_net_dict[fitting_key] is the single definition of fitting of physical properties with user-defined name `fitting_key`." - - ca = Argument( - "multi", - dict, - [ - Argument( - "descriptor", - dict, - [], - [descrpt_variant_type_args()], - doc=doc_descrpt, - fold_subdoc=True, - ), - Argument("fitting_net_dict", dict, doc=doc_fitting_net_dict), - ], - doc=doc_only_tf_supported + "Multiple-task model.", - ) - return ca - - def pairwise_dprc() -> Argument: qm_model_args = model_args(exclude_hybrid=True) qm_model_args.name = "qm_model" @@ -1713,17 +1689,6 @@ def learning_rate_args(): ) -def learning_rate_dict_args(): - doc_learning_rate_dict = ( - "The dictionary of definitions of learning rates in multi-task mode. " - "Each learning_rate_dict[fitting_key], with user-defined name `fitting_key` in `model/fitting_net_dict`, is the single definition of learning rate.\n" - ) - ca = Argument( - "learning_rate_dict", dict, [], [], optional=True, doc=doc_learning_rate_dict - ) - return ca - - # --- Loss configurations: --- # def start_pref(item, label=None, abbr=None): if label is None: @@ -2089,15 +2054,6 @@ def loss_args(): return ca -def loss_dict_args(): - doc_loss_dict = ( - "The dictionary of definitions of multiple loss functions in multi-task mode. " - "Each loss_dict[fitting_key], with user-defined name `fitting_key` in `model/fitting_net_dict`, is the single definition of loss function, whose type should be set to `tensor`, `ener` or left unset.\n" - ) - ca = Argument("loss_dict", dict, [], [], optional=True, doc=doc_loss_dict) - return ca - - # --- Training configurations: --- # def training_data_args(): # ! added by Ziyao: new specification style for data systems. link_sys = make_link("systems", "training/training_data/systems") @@ -2294,18 +2250,6 @@ def training_args(): # ! modified by Ziyao: data configuration isolated. doc_tensorboard = "Enable tensorboard" doc_tensorboard_log_dir = "The log directory of tensorboard outputs" doc_tensorboard_freq = "The frequency of writing tensorboard events." - doc_data_dict = ( - "The dictionary of multi DataSystems in multi-task mode. " - "Each data_dict[fitting_key], with user-defined name `fitting_key` in `model/fitting_net_dict`, " - "contains training data and optional validation data definitions." - ) - doc_fitting_weight = ( - "Each fitting_weight[fitting_key], with user-defined name `fitting_key` in `model/fitting_net_dict`, " - "is the training weight of fitting net `fitting_key`. " - "Fitting nets with higher weights will be selected with higher probabilities to be trained in one step. " - "Weights will be normalized and minus ones will be ignored. " - "If not set, each fitting net will be equally selected when training." - ) doc_warmup_steps = ( "The number of steps for learning rate warmup. During warmup, " "the learning rate begins at zero and progressively increases linearly to `start_lr`, " @@ -2387,8 +2331,6 @@ def training_args(): # ! modified by Ziyao: data configuration isolated. Argument( "tensorboard_freq", int, optional=True, default=1, doc=doc_tensorboard_freq ), - Argument("data_dict", dict, optional=True, doc=doc_data_dict), - Argument("fitting_weight", dict, optional=True, doc=doc_fitting_weight), Argument( "warmup_steps", int, @@ -2469,211 +2411,13 @@ def gen_args(**kwargs) -> List[Argument]: return [ model_args(), learning_rate_args(), - learning_rate_dict_args(), loss_args(), - loss_dict_args(), training_args(), nvnmd_args(), ] -def normalize_multi_task(data): - # single-task or multi-task mode - if data["model"].get("type", "standard") not in ("standard", "multi"): - return data - single_fitting_net = "fitting_net" in data["model"].keys() - single_training_data = "training_data" in data["training"].keys() - single_valid_data = "validation_data" in data["training"].keys() - single_loss = "loss" in data.keys() - single_learning_rate = "learning_rate" in data.keys() - multi_fitting_net = "fitting_net_dict" in data["model"].keys() - multi_training_data = "data_dict" in data["training"].keys() - multi_loss = "loss_dict" in data.keys() - multi_fitting_weight = "fitting_weight" in data["training"].keys() - multi_learning_rate = "learning_rate_dict" in data.keys() - assert (single_fitting_net == single_training_data) and ( - multi_fitting_net == multi_training_data - ), ( - "In single-task mode, 'model/fitting_net' and 'training/training_data' must be defined at the same time! " - "While in multi-task mode, 'model/fitting_net_dict', 'training/data_dict' " - "must be defined at the same time! Please check your input script. " - ) - assert not (single_fitting_net and multi_fitting_net), ( - "Single-task mode and multi-task mode can not be performed together. " - "Please check your input script and choose just one format! " - ) - assert ( - single_fitting_net or multi_fitting_net - ), "Please define your fitting net and training data! " - if multi_fitting_net: - assert not single_valid_data, ( - "In multi-task mode, 'training/validation_data' should not appear " - "outside 'training/data_dict'! Please check your input script." - ) - assert ( - not single_loss - ), "In multi-task mode, please use 'model/loss_dict' in stead of 'model/loss'! " - assert ( - "type_map" in data["model"] - ), "In multi-task mode, 'model/type_map' must be defined! " - data["model"]["type"] = "multi" - data["model"]["fitting_net_dict"] = normalize_fitting_net_dict( - data["model"]["fitting_net_dict"] - ) - data["training"]["data_dict"] = normalize_data_dict( - data["training"]["data_dict"] - ) - data["loss_dict"] = ( - normalize_loss_dict( - data["model"]["fitting_net_dict"].keys(), data["loss_dict"] - ) - if multi_loss - else {} - ) - if multi_learning_rate: - data["learning_rate_dict"] = normalize_learning_rate_dict( - data["model"]["fitting_net_dict"].keys(), data["learning_rate_dict"] - ) - elif single_learning_rate: - data["learning_rate_dict"] = ( - normalize_learning_rate_dict_with_single_learning_rate( - data["model"]["fitting_net_dict"].keys(), data["learning_rate"] - ) - ) - fitting_weight = ( - data["training"]["fitting_weight"] if multi_fitting_weight else None - ) - data["training"]["fitting_weight"] = normalize_fitting_weight( - data["model"]["fitting_net_dict"].keys(), - data["training"]["data_dict"].keys(), - fitting_weight=fitting_weight, - ) - else: - assert not multi_loss, "In single-task mode, please use 'model/loss' in stead of 'model/loss_dict'! " - assert not multi_learning_rate, "In single-task mode, please use 'model/learning_rate' in stead of 'model/learning_rate_dict'! " - return data - - -def normalize_fitting_net_dict(fitting_net_dict): - new_dict = {} - base = Argument("base", dict, [], [fitting_variant_type_args()], doc="") - for fitting_key_item in fitting_net_dict: - data = base.normalize_value( - fitting_net_dict[fitting_key_item], trim_pattern="_*" - ) - base.check_value(data, strict=True) - new_dict[fitting_key_item] = data - return new_dict - - -def normalize_data_dict(data_dict): - new_dict = {} - base = Argument( - "base", dict, [training_data_args(), validation_data_args()], [], doc="" - ) - for data_system_key_item in data_dict: - data = base.normalize_value(data_dict[data_system_key_item], trim_pattern="_*") - base.check_value(data, strict=True) - new_dict[data_system_key_item] = data - return new_dict - - -def normalize_loss_dict(fitting_keys, loss_dict): - # check the loss dict - failed_loss_keys = [item for item in loss_dict if item not in fitting_keys] - assert not failed_loss_keys, f"Loss dict key(s) {failed_loss_keys!s} not have corresponding fitting keys in {list(fitting_keys)!s}! " - new_dict = {} - base = Argument("base", dict, [], [loss_variant_type_args()], doc="") - for item in loss_dict: - data = base.normalize_value(loss_dict[item], trim_pattern="_*") - base.check_value(data, strict=True) - new_dict[item] = data - return new_dict - - -def normalize_learning_rate_dict(fitting_keys, learning_rate_dict): - # check the learning_rate dict - failed_learning_rate_keys = [ - item for item in learning_rate_dict if item not in fitting_keys - ] - assert not failed_learning_rate_keys, f"Learning rate dict key(s) {failed_learning_rate_keys!s} not have corresponding fitting keys in {list(fitting_keys)!s}! " - new_dict = {} - base = Argument("base", dict, [], [learning_rate_variant_type_args()], doc="") - for item in learning_rate_dict: - data = base.normalize_value(learning_rate_dict[item], trim_pattern="_*") - base.check_value(data, strict=True) - new_dict[item] = data - return new_dict - - -def normalize_learning_rate_dict_with_single_learning_rate(fitting_keys, learning_rate): - new_dict = {} - base = Argument("base", dict, [], [learning_rate_variant_type_args()], doc="") - data = base.normalize_value(learning_rate, trim_pattern="_*") - base.check_value(data, strict=True) - for fitting_key in fitting_keys: - new_dict[fitting_key] = data - return new_dict - - -def normalize_fitting_weight(fitting_keys, data_keys, fitting_weight=None): - # check the mapping - failed_data_keys = [item for item in data_keys if item not in fitting_keys] - assert not failed_data_keys, f"Data dict key(s) {failed_data_keys!s} not have corresponding fitting keys in {list(fitting_keys)!s}! " - empty_fitting_keys = [] - valid_fitting_keys = [] - for item in fitting_keys: - if item not in data_keys: - empty_fitting_keys.append(item) - else: - valid_fitting_keys.append(item) - if empty_fitting_keys: - log.warning( - f"Fitting net(s) {empty_fitting_keys!s} have no data and will not be used in training." - ) - num_pair = len(valid_fitting_keys) - assert num_pair > 0, "No valid training data systems for fitting nets!" - - # check and normalize the fitting weight - new_weight = {} - if fitting_weight is None: - equal_weight = 1.0 / num_pair - for item in fitting_keys: - new_weight[item] = equal_weight if item in valid_fitting_keys else 0.0 - else: - failed_weight_keys = [ - item for item in fitting_weight if item not in fitting_keys - ] - assert not failed_weight_keys, f"Fitting weight key(s) {failed_weight_keys!s} not have corresponding fitting keys in {list(fitting_keys)!s}! " - sum_prob = 0.0 - for item in fitting_keys: - if item in valid_fitting_keys: - if ( - item in fitting_weight - and isinstance(fitting_weight[item], (int, float)) - and fitting_weight[item] > 0.0 - ): - sum_prob += fitting_weight[item] - new_weight[item] = fitting_weight[item] - else: - valid_fitting_keys.remove(item) - log.warning( - f"Fitting net '{item}' has zero or invalid weight " - "and will not be used in training." - ) - new_weight[item] = 0.0 - else: - new_weight[item] = 0.0 - assert sum_prob > 0.0, "No valid training weight for fitting nets!" - # normalize - for item in new_weight: - new_weight[item] /= sum_prob - return new_weight - - def normalize(data): - data = normalize_multi_task(data) - base = Argument("base", dict, gen_args()) data = base.normalize_value(data, trim_pattern="_*") base.check_value(data, strict=True) diff --git a/doc/freeze/freeze.md b/doc/freeze/freeze.md index b80928a119..5bd63a4840 100644 --- a/doc/freeze/freeze.md +++ b/doc/freeze/freeze.md @@ -22,14 +22,11 @@ $ dp --pt freeze -o model.pth in the folder where the model is trained. The output model is called `model.pth`. -::: - -:::: +In [multi-task mode](../train/multi-task-training-pt.md), you need to choose one available heads (e.g. `CHOSEN_BRANCH`) by `--head` +to specify which model branch you want to freeze: -In [multi-task mode](../train/multi-task-training.md): +```bash +$ dp --pt freeze -o model_branch1.pth --head CHOSEN_BRANCH +``` -- This process will in default output several models, each of which contains the common descriptor and - one of the user-defined fitting nets in {ref}`fitting_net_dict `, let's name it `fitting_key`, together frozen in `graph_{fitting_key}.pb`. - Those frozen models are exactly the same as single-task output with fitting net `fitting_key`. -- If you add `--united-model` option in this situation, - the total multi-task model will be frozen into one unit `graph.pb`, which is mainly for multi-task initialization and can not be used directly for inference. +The output model is called `model_branch1.pth`, which is the specifically frozen model with the `CHOSEN_BRANCH` head. diff --git a/doc/train/multi-task-training-tf.md b/doc/train/multi-task-training-tf.md index 48a9ef44e9..9c19025f22 100644 --- a/doc/train/multi-task-training-tf.md +++ b/doc/train/multi-task-training-tf.md @@ -1,162 +1,5 @@ # Multi-task training {{ tensorflow_icon }} -:::{note} -**Supported backends**: TensorFlow {{ tensorflow_icon }} +:::{warning} +We have deprecated TensorFlow backend multi-task training, please use the PyTorch one [here](multi-task-training-pt.md). ::: - - - -## Theory - -The multi-task training process can simultaneously handle different datasets with properties that cannot be fitted in one network (e.g. properties from DFT calculations under different exchange-correlation functionals or different basis sets). -These datasets are denoted by $\boldsymbol x^{(1)}, \dots, \boldsymbol x^{(n_t)}$. -For each dataset, a training task is defined as - -```math - \min_{\boldsymbol \theta} L^{(t)} (\boldsymbol x^{(t)}; \boldsymbol \theta^{(t)}, \tau), \quad t=1, \dots, n_t. -``` - -During the multi-task training process, all tasks share one descriptor with trainable parameters $\boldsymbol{\theta}_ {d}$, while each of them has its own fitting network with trainable parameters $\boldsymbol{\theta}_ f^{(t)}$, thus -$\boldsymbol{\theta}^{(t)} = \{ \boldsymbol{\theta}_ {d} , \boldsymbol{\theta}_ {f}^{(t)} \}$. -At each training step, a task is randomly picked from ${1, \dots, n_t}$, and the Adam optimizer is executed to minimize $L^{(t)}$ for one step to update the parameter $\boldsymbol \theta^{(t)}$. -If different fitting networks have the same architecture, they can share the parameters of some layers -to improve training efficiency.[^1] - -[^1]: This section is built upon Jinzhe Zeng, Duo Zhang, Denghui Lu, Pinghui Mo, Zeyu Li, Yixiao Chen, Marián Rynik, Li'ang Huang, Ziyao Li, Shaochen Shi, Yingze Wang, Haotian Ye, Ping Tuo, Jiabin Yang, Ye Ding, Yifan Li, Davide Tisi, Qiyu Zeng, Han Bao, Yu Xia, Jiameng Huang, Koki Muraoka, Yibo Wang, Junhan Chang, Fengbo Yuan, Sigbjørn Løland Bore, Chun Cai, Yinnian Lin, Bo Wang, Jiayan Xu, Jia-Xin Zhu, Chenxing Luo, Yuzhi Zhang, Rhys E. A. Goodall, Wenshuo Liang, Anurag Kumar Singh, Sikai Yao, Jingchao Zhang, Renata Wentzcovitch, Jiequn Han, Jie Liu, Weile Jia, Darrin M. York, Weinan E, Roberto Car, Linfeng Zhang, Han Wang, [J. Chem. Phys. 159, 054801 (2023)](https://doi.org/10.1063/5.0155600) licensed under a [Creative Commons Attribution (CC BY) license](http://creativecommons.org/licenses/by/4.0/). - -## Perform the multi-task training - -Training on multiple data sets (each data set contains several data systems) can be performed in multi-task mode, -with one common descriptor and multiple specific fitting nets for each data set. -One can simply switch the following parameters in training input script to perform multi-task mode: - -- {ref}`fitting_net ` --> {ref}`fitting_net_dict `, - each key of which can be one individual fitting net. -- {ref}`training_data `, {ref}`validation_data ` - --> {ref}`data_dict `, each key of which can be one individual data set contains - several data systems for corresponding fitting net, the keys must be consistent with those in - {ref}`fitting_net_dict `. -- {ref}`loss ` --> {ref}`loss_dict `, each key of which can be one individual loss setting - for corresponding fitting net, the keys must be consistent with those in - {ref}`fitting_net_dict `, if not set, the corresponding fitting net will use the default loss. -- (Optional) {ref}`fitting_weight `, each key of which can be a non-negative integer or float, - deciding the chosen probability for corresponding fitting net in training, if not set or invalid, - the corresponding fitting net will not be used. - -The training procedure will automatically choose single-task or multi-task mode, based on the above parameters. -Note that parameters of single-task mode and multi-task mode can not be mixed. - -An example input for training energy and dipole in water system can be found here: [multi-task input on water](../../examples/water_multi_task/ener_dipole/input.json). - -The supported descriptors for multi-task mode are listed: - -- {ref}`se_a (se_e2_a) ` -- {ref}`se_r (se_e2_r) ` -- {ref}`se_at (se_e3) ` -- {ref}`se_atten ` -- {ref}`se_atten_v2 ` -- {ref}`hybrid ` - -The supported fitting nets for multi-task mode are listed: - -- {ref}`ener ` -- {ref}`dipole ` -- {ref}`polar ` - -The output of `dp freeze` command in multi-task mode can be seen in [freeze command](../freeze/freeze.md). - -## Initialization from pre-trained multi-task model - -For advance training in multi-task mode, one can first train the descriptor on several upstream datasets and then transfer it on new downstream ones with newly added fitting nets. -At the second step, you can also inherit some fitting nets trained on upstream datasets, by merely adding fitting net keys in {ref}`fitting_net_dict ` and -optional fitting net weights in {ref}`fitting_weight `. - -Take [multi-task input on water](../../examples/water_multi_task/ener_dipole/input.json) again for example. -You can first train a multi-task model using input script with the following {ref}`model ` part: - -```json - "model": { - "type_map": ["O", "H"], - "descriptor": { - "type": "se_e2_a", - "sel": [46, 92], - "rcut_smth": 0.5, - "rcut": 6.0, - "neuron": [25, 50, 100], - "type_one_side": true - }, - "fitting_net_dict": { - "water_dipole": { - "type": "dipole", - "neuron": [100, 100, 100] - }, - "water_ener": { - "neuron": [240, 240, 240], - "resnet_dt": true - } - }, - } -``` - -After training, you can freeze this multi-task model into one unit graph: - -```bash -$ dp freeze -o graph.pb --united-model -``` - -Then if you want to transfer the trained descriptor and some fitting nets (take `water_ener` for example) to newly added datasets with new fitting net `water_ener_2`, -you can modify the {ref}`model ` part of the new input script in a more simplified way: - -```json - "model": { - "type_map": ["O", "H"], - "descriptor": {}, - "fitting_net_dict": { - "water_ener": {}, - "water_ener_2": { - "neuron": [240, 240, 240], - "resnet_dt": true, - } - }, - } -``` - -It will autocomplete the configurations according to the frozen graph. - -Note that for newly added fitting net keys, other parts in the input script, including {ref}`data_dict ` and {ref}`loss_dict ` (optionally {ref}`fitting_weight `), -should be set explicitly. While for old fitting net keys, it will inherit the old configurations if not set. - -Finally, you can perform the modified multi-task training from the frozen model with command: - -```bash -$ dp train input.json --init_frz_model graph.pb -``` - -## Share layers among energy fitting networks - -The multi-task training can be used to train multiple levels of energies (e.g. DFT and CCSD(T)) at the same time. -In this situation, one can set {ref}`model/fitting_net[ener]/layer_name>` to share some of layers among fitting networks. -The architecture of the layers with the same name should be the same. - -For example, if one want to share the first and the third layers for two three-hidden-layer fitting networks, the following parameters should be set. - -```json -"fitting_net_dict": { - "ccsd": { - "neuron": [ - 240, - 240, - 240 - ], - "layer_name": ["l0", null, "l2", null] - }, - "wb97m": { - "neuron": [ - 240, - 240, - 240 - ], - "layer_name": ["l0", null, "l2", null] - } -} -``` diff --git a/examples/water_multi_task/ener_dipole/input.json b/examples/water_multi_task/ener_dipole/input.json deleted file mode 100644 index 45b49c5d90..0000000000 --- a/examples/water_multi_task/ener_dipole/input.json +++ /dev/null @@ -1,135 +0,0 @@ -{ - "_comment1": "that's all", - "model": { - "type_map": [ - "O", - "H" - ], - "descriptor": { - "type": "se_e2_a", - "sel": [ - 46, - 92 - ], - "rcut_smth": 0.5, - "rcut": 6.0, - "neuron": [ - 25, - 50, - 100 - ], - "resnet_dt": false, - "axis_neuron": 16, - "type_one_side": true, - "precision": "float64", - "seed": 1, - "_comment2": " that's all" - }, - "fitting_net_dict": { - "water_dipole": { - "type": "dipole", - "sel_type": [ - 0 - ], - "neuron": [ - 100, - 100, - 100 - ], - "resnet_dt": true, - "precision": "float64", - "seed": 1, - "_comment3": " that's all" - }, - "water_ener": { - "neuron": [ - 240, - 240, - 240 - ], - "resnet_dt": true, - "precision": "float64", - "seed": 1, - "_comment4": " that's all" - } - }, - "_comment5": " that's all" - }, - "learning_rate": { - "type": "exp", - "decay_steps": 5000, - "start_lr": 0.001, - "stop_lr": 3.51e-08, - "_comment6": "that's all" - }, - "loss_dict": { - "water_dipole": { - "type": "tensor", - "pref": 1.0, - "pref_atomic": 1.0, - "_comment7": " that's all" - }, - "water_ener": { - "type": "ener", - "start_pref_e": 0.02, - "limit_pref_e": 1, - "start_pref_f": 1000, - "limit_pref_f": 1, - "start_pref_v": 0, - "limit_pref_v": 0, - "_comment8": " that's all" - } - }, - "training": { - "data_dict": { - "water_dipole": { - "training_data": { - "systems": [ - "../../water_tensor/dipole/training_data/atomic_system", - "../../water_tensor/dipole/training_data/global_system" - ], - "batch_size": "auto", - "_comment9": "that's all" - }, - "validation_data": { - "systems": [ - "../../water_tensor/dipole/validation_data/atomic_system", - "../../water_tensor/dipole/validation_data/global_system" - ], - "batch_size": 1, - "numb_btch": 3, - "_comment10": "that's all" - } - }, - "water_ener": { - "training_data": { - "systems": [ - "../../water/data/data_0/", - "../../water/data/data_1/", - "../../water/data/data_2/" - ], - "batch_size": "auto", - "_comment11": "that's all" - }, - "validation_data": { - "systems": [ - "../../water/data/data_3/" - ], - "batch_size": 1, - "numb_btch": 3, - "_comment12": "that's all" - } - } - }, - "fitting_weight": { - "water_dipole": 10, - "water_ener": 20 - }, - "numb_steps": 1000000, - "seed": 10, - "disp_file": "lcurve.out", - "disp_freq": 100, - "save_freq": 1000, - "_comment13": "that's all" - } -} diff --git a/source/tests/common/test_examples.py b/source/tests/common/test_examples.py index 6d5e34fedf..f7f1593f6f 100644 --- a/source/tests/common/test_examples.py +++ b/source/tests/common/test_examples.py @@ -36,7 +36,6 @@ p_examples / "water_tensor" / "polar" / "polar_input.json", p_examples / "water_tensor" / "dipole" / "dipole_input_torch.json", p_examples / "water_tensor" / "polar" / "polar_input_torch.json", - p_examples / "water_multi_task" / "ener_dipole" / "input.json", p_examples / "fparam" / "train" / "input.json", p_examples / "fparam" / "train" / "input_aparam.json", p_examples / "zinc_protein" / "zinc_se_a_mask.json", diff --git a/source/tests/tf/test_init_frz_model_multi.py b/source/tests/tf/test_init_frz_model_multi.py deleted file mode 100644 index b6209a7e69..0000000000 --- a/source/tests/tf/test_init_frz_model_multi.py +++ /dev/null @@ -1,254 +0,0 @@ -# SPDX-License-Identifier: LGPL-3.0-or-later -import json -import os -import unittest - -import numpy as np - -from deepmd.tf.env import ( - GLOBAL_NP_FLOAT_PRECISION, - tf, -) -from deepmd.tf.train.run_options import ( - RunOptions, -) -from deepmd.tf.train.trainer import ( - DPTrainer, -) -from deepmd.tf.utils.argcheck import ( - normalize, -) -from deepmd.tf.utils.compat import ( - update_deepmd_input, -) -from deepmd.tf.utils.data_system import ( - DeepmdDataSystem, -) -from deepmd.tf.utils.multi_init import ( - replace_model_params_with_frz_multi_model, -) - -from .common import ( - j_loader, - run_dp, - tests_path, -) - -if GLOBAL_NP_FLOAT_PRECISION == np.float32: - default_places = 4 -else: - default_places = 10 - - -def _file_delete(file): - if os.path.isdir(file): - os.rmdir(file) - elif os.path.isfile(file): - os.remove(file) - - -def _init_models(): - data_file = str(tests_path / os.path.join("init_frz_model", "data")) - frozen_model = str(tests_path / "init_frz_multi_unit.pb") - ckpt = str(tests_path / "init_frz_multi.ckpt") - run_opt_ckpt = RunOptions(init_model=ckpt, log_level=20) - run_opt_frz = RunOptions(init_frz_model=frozen_model, log_level=20) - INPUT = str(tests_path / "input.json") - jdata = j_loader(str(tests_path / os.path.join("init_frz_model", "input.json"))) - fitting_config = jdata["model"].pop("fitting_net") - loss_config = jdata.pop("loss") - learning_rate_config = jdata.pop("learning_rate") - training_data_config = jdata["training"].pop("training_data") - validation_data_config = jdata["training"].pop("validation_data") - jdata["training"]["data_dict"] = {} - jdata["training"]["data_dict"]["water_ener"] = {} - jdata["training"]["data_dict"]["water_ener"]["training_data"] = training_data_config - jdata["training"]["data_dict"]["water_ener"]["training_data"]["systems"] = data_file - jdata["training"]["data_dict"]["water_ener"]["validation_data"] = ( - validation_data_config - ) - jdata["training"]["data_dict"]["water_ener"]["validation_data"]["systems"] = ( - data_file - ) - jdata["training"]["save_ckpt"] = ckpt - jdata["model"]["fitting_net_dict"] = {} - jdata["model"]["fitting_net_dict"]["water_ener"] = fitting_config - jdata["loss_dict"] = {} - jdata["loss_dict"]["water_ener"] = loss_config - jdata["learning_rate_dict"] = {} - jdata["learning_rate_dict"]["water_ener"] = learning_rate_config - with open(INPUT, "w") as fp: - json.dump(jdata, fp, indent=4) - ret = run_dp("dp train " + INPUT) - np.testing.assert_equal(ret, 0, "DP train failed!") - ret = run_dp( - "dp freeze -c " + str(tests_path) + " -o " + frozen_model + " --united-model" - ) - np.testing.assert_equal(ret, 0, "DP freeze failed!") - jdata = update_deepmd_input(jdata, warning=True, dump="input_v2_compat.json") - jdata = normalize(jdata) - model_ckpt = DPTrainer(jdata, run_opt=run_opt_ckpt) - - # change the multi-task branch - jdata["model"]["fitting_net_dict"]["water_ener"] = {} - jdata["model"]["fitting_net_dict"]["water_ener_new"] = fitting_config - jdata["loss_dict"] = {} - jdata["loss_dict"]["water_ener_new"] = loss_config - jdata["learning_rate_dict"] = {} - jdata["learning_rate_dict"]["water_ener_new"] = learning_rate_config - jdata["training"]["data_dict"] = {} - jdata["training"]["data_dict"]["water_ener_new"] = {} - jdata["training"]["data_dict"]["water_ener_new"]["training_data"] = ( - training_data_config - ) - jdata["training"]["data_dict"]["water_ener_new"]["training_data"]["systems"] = ( - data_file - ) - jdata["training"]["data_dict"]["water_ener_new"]["validation_data"] = ( - validation_data_config - ) - jdata["training"]["data_dict"]["water_ener_new"]["validation_data"]["systems"] = ( - data_file - ) - jdata["training"].pop("fitting_weight") - - jdata = replace_model_params_with_frz_multi_model(jdata, frozen_model) - jdata = update_deepmd_input(jdata, warning=True, dump="input_v2_compat.json") - if "validation_data" in jdata["training"]: - jdata["training"].pop("validation_data") - jdata = normalize(jdata) - model_frz = DPTrainer(jdata, run_opt=run_opt_frz) - rcut = model_ckpt.model.get_rcut() - type_map = model_ckpt.model.get_type_map() - data = DeepmdDataSystem( - systems=[data_file], - batch_size=1, - test_size=1, - rcut=rcut, - type_map=type_map, - trn_all_set=True, - ) - data_requirement = { - "energy": { - "ndof": 1, - "atomic": False, - "must": False, - "high_prec": True, - "type_sel": None, - "repeat": 1, - "default": 0.0, - }, - "force": { - "ndof": 3, - "atomic": True, - "must": False, - "high_prec": False, - "type_sel": None, - "repeat": 1, - "default": 0.0, - }, - "virial": { - "ndof": 9, - "atomic": False, - "must": False, - "high_prec": False, - "type_sel": None, - "repeat": 1, - "default": 0.0, - }, - "atom_ener": { - "ndof": 1, - "atomic": True, - "must": False, - "high_prec": False, - "type_sel": None, - "repeat": 1, - "default": 0.0, - }, - "atom_pref": { - "ndof": 1, - "atomic": True, - "must": False, - "high_prec": False, - "type_sel": None, - "repeat": 3, - "default": 0.0, - }, - } - data.add_dict(data_requirement) - stop_batch = jdata["training"]["numb_steps"] - - return INPUT, ckpt, frozen_model, model_ckpt, model_frz, data, stop_batch - - -class TestInitFrzModelMulti(unittest.TestCase): - @classmethod - def setUpClass(cls): - ( - cls.INPUT, - cls.CKPT, - cls.FROZEN_MODEL, - CKPT_TRAINER, - FRZ_TRAINER, - VALID_DATA, - STOP_BATCH, - ) = _init_models() - - cls.dp_ckpt = CKPT_TRAINER - cls.dp_frz = FRZ_TRAINER - cls.valid_data_dict = {"water_ener": VALID_DATA} - cls.valid_data_dict_new = { - "water_ener": VALID_DATA, - "water_ener_new": VALID_DATA, - } - cls.stop_batch = STOP_BATCH - - @classmethod - def tearDownClass(cls): - _file_delete(cls.INPUT) - _file_delete(cls.FROZEN_MODEL) - _file_delete("out.json") - _file_delete(str(tests_path / "checkpoint")) - _file_delete(cls.CKPT + ".meta") - _file_delete(cls.CKPT + ".index") - _file_delete(cls.CKPT + ".data-00000-of-00001") - _file_delete(cls.CKPT + "-0.meta") - _file_delete(cls.CKPT + "-0.index") - _file_delete(cls.CKPT + "-0.data-00000-of-00001") - _file_delete(cls.CKPT + "-1.meta") - _file_delete(cls.CKPT + "-1.index") - _file_delete(cls.CKPT + "-1.data-00000-of-00001") - _file_delete("input_v2_compat.json") - _file_delete("lcurve.out") - - def test_single_frame(self): - test_sys_name = "water_ener" - valid_batch = self.valid_data_dict[test_sys_name].get_batch() - natoms = valid_batch["natoms_vec"] - tf.reset_default_graph() - self.dp_ckpt.build(self.valid_data_dict, self.stop_batch) - self.dp_ckpt._init_session() - feed_dict_ckpt = self.dp_ckpt.get_feed_dict(valid_batch, is_training=False) - ckpt_rmse_ckpt = self.dp_ckpt.loss_dict[test_sys_name].eval( - self.dp_ckpt.sess, feed_dict_ckpt, natoms - ) - tf.reset_default_graph() - - self.dp_frz.build(self.valid_data_dict_new, self.stop_batch) - self.dp_frz._init_session() - feed_dict_frz = self.dp_frz.get_feed_dict(valid_batch, is_training=False) - ckpt_rmse_frz = self.dp_frz.loss_dict[test_sys_name].eval( - self.dp_frz.sess, feed_dict_frz, natoms - ) - tf.reset_default_graph() - - # check values - np.testing.assert_almost_equal( - ckpt_rmse_ckpt["rmse_e"], ckpt_rmse_frz["rmse_e"], default_places - ) - np.testing.assert_almost_equal( - ckpt_rmse_ckpt["rmse_f"], ckpt_rmse_frz["rmse_f"], default_places - ) - np.testing.assert_almost_equal( - ckpt_rmse_ckpt["rmse_v"], ckpt_rmse_frz["rmse_v"], default_places - ) diff --git a/source/tests/tf/test_layer_name.py b/source/tests/tf/test_layer_name.py deleted file mode 100644 index c61c5eafe2..0000000000 --- a/source/tests/tf/test_layer_name.py +++ /dev/null @@ -1,150 +0,0 @@ -# SPDX-License-Identifier: LGPL-3.0-or-later -import numpy as np - -from deepmd.tf.common import ( - j_must_have, -) -from deepmd.tf.descriptor import ( - DescrptSeA, -) -from deepmd.tf.env import ( - tf, -) -from deepmd.tf.fit import ( - DipoleFittingSeA, - EnerFitting, -) -from deepmd.tf.model import ( - MultiModel, -) - -from .common import ( - DataSystem, - del_data, - gen_data, - j_loader, -) - -GLOBAL_ENER_FLOAT_PRECISION = tf.float64 -GLOBAL_TF_FLOAT_PRECISION = tf.float64 -GLOBAL_NP_FLOAT_PRECISION = np.float64 - - -class TestModel(tf.test.TestCase): - def setUp(self): - gen_data() - - def tearDown(self): - del_data() - - def test_model(self): - """Two fittings which share the same parameters should give the same result.""" - jfile = "water_layer_name.json" - jdata = j_loader(jfile) - - systems = j_must_have(jdata, "systems") - set_pfx = "set" - batch_size = j_must_have(jdata, "batch_size") - test_size = j_must_have(jdata, "numb_test") - batch_size = 1 - test_size = 1 - rcut = j_must_have(jdata["model"]["descriptor"], "rcut") - - data = DataSystem(systems, set_pfx, batch_size, test_size, rcut, run_opt=None) - - test_data = data.get_test() - numb_test = 1 - - jdata["model"]["descriptor"].pop("type", None) - jdata["model"]["descriptor"]["multi_task"] = True - descrpt = DescrptSeA(**jdata["model"]["descriptor"], uniform_seed=True) - fitting_dict = {} - fitting_type_dict = {} - for fitting_key in jdata["model"]["fitting_net_dict"]: - item_fitting_param = jdata["model"]["fitting_net_dict"][fitting_key] - item_fitting_type = item_fitting_param.get("type", "ener") - fitting_type_dict[fitting_key] = item_fitting_type - item_fitting_param.pop("type", None) - item_fitting_param.pop("fit_diag", None) - item_fitting_param["ntypes"] = descrpt.get_ntypes() - item_fitting_param["dim_descrpt"] = descrpt.get_dim_out() - if item_fitting_type == "ener": - fitting_dict[fitting_key] = EnerFitting( - **item_fitting_param, uniform_seed=True - ) - elif item_fitting_type == "dipole": - fitting_dict[fitting_key] = DipoleFittingSeA( - **item_fitting_param, uniform_seed=True - ) - else: - raise RuntimeError("Test should not be here!") - model = MultiModel(descrpt, fitting_dict, fitting_type_dict) - - input_data = { - "coord": [test_data["coord"]], - "box": [test_data["box"]], - "type": [test_data["type"]], - "natoms_vec": [test_data["natoms_vec"]], - "default_mesh": [test_data["default_mesh"]], - } - - for fitting_key in jdata["model"]["fitting_net_dict"]: - model._compute_input_stat(input_data, fitting_key=fitting_key) - model.descrpt.merge_input_stats(model.descrpt.stat_dict) - model.descrpt.bias_atom_e = data.compute_energy_shift() - - t_prop_c = tf.placeholder(tf.float32, [5], name="t_prop_c") - t_energy = tf.placeholder(GLOBAL_ENER_FLOAT_PRECISION, [None], name="t_energy") - t_force = tf.placeholder(GLOBAL_TF_FLOAT_PRECISION, [None], name="t_force") - t_virial = tf.placeholder(GLOBAL_TF_FLOAT_PRECISION, [None], name="t_virial") - t_atom_ener = tf.placeholder( - GLOBAL_TF_FLOAT_PRECISION, [None], name="t_atom_ener" - ) - t_coord = tf.placeholder(GLOBAL_TF_FLOAT_PRECISION, [None], name="i_coord") - t_type = tf.placeholder(tf.int32, [None], name="i_type") - t_natoms = tf.placeholder(tf.int32, [model.ntypes + 2], name="i_natoms") - t_box = tf.placeholder(GLOBAL_TF_FLOAT_PRECISION, [None, 9], name="i_box") - t_mesh = tf.placeholder(tf.int32, [None], name="i_mesh") - is_training = tf.placeholder(tf.bool) - t_fparam = None - - model_pred = model.build( - t_coord, - t_type, - t_natoms, - t_box, - t_mesh, - t_fparam, - suffix="_layer_name", - reuse=False, - ) - - e_energy1 = model_pred["water_ener"]["energy"] - e_force1 = model_pred["water_ener"]["force"] - e_virial1 = model_pred["water_ener"]["virial"] - e_energy2 = model_pred["water_ener2"]["energy"] - e_force2 = model_pred["water_ener2"]["force"] - e_virial2 = model_pred["water_ener2"]["virial"] - feed_dict_test = { - t_prop_c: test_data["prop_c"], - t_energy: test_data["energy"][:numb_test], - t_force: np.reshape(test_data["force"][:numb_test, :], [-1]), - t_virial: np.reshape(test_data["virial"][:numb_test, :], [-1]), - t_atom_ener: np.reshape(test_data["atom_ener"][:numb_test, :], [-1]), - t_coord: np.reshape(test_data["coord"][:numb_test, :], [-1]), - t_box: test_data["box"][:numb_test, :], - t_type: np.reshape(test_data["type"][:numb_test, :], [-1]), - t_natoms: test_data["natoms_vec"], - t_mesh: test_data["default_mesh"], - is_training: False, - } - - with self.cached_session() as sess: - sess.run(tf.global_variables_initializer()) - [e1, f1, v1, e2, f2, v2] = sess.run( - [e_energy1, e_force1, e_virial1, e_energy2, e_force2, e_virial2], - feed_dict=feed_dict_test, - ) - np.testing.assert_allclose(e1, e2, rtol=1e-5, atol=1e-5) - np.testing.assert_allclose(f1, f2, rtol=1e-5, atol=1e-5) - np.testing.assert_allclose(v1, v2, rtol=1e-5, atol=1e-5) diff --git a/source/tests/tf/test_model_multi.py b/source/tests/tf/test_model_multi.py deleted file mode 100644 index 4b0605b8a7..0000000000 --- a/source/tests/tf/test_model_multi.py +++ /dev/null @@ -1,264 +0,0 @@ -# SPDX-License-Identifier: LGPL-3.0-or-later -import numpy as np - -from deepmd.tf.common import ( - j_must_have, -) -from deepmd.tf.descriptor import ( - DescrptSeA, -) -from deepmd.tf.env import ( - tf, -) -from deepmd.tf.fit import ( - DipoleFittingSeA, - EnerFitting, -) -from deepmd.tf.model import ( - MultiModel, -) - -from .common import ( - DataSystem, - del_data, - finite_difference, - gen_data, - j_loader, - strerch_box, -) - -GLOBAL_ENER_FLOAT_PRECISION = tf.float64 -GLOBAL_TF_FLOAT_PRECISION = tf.float64 -GLOBAL_NP_FLOAT_PRECISION = np.float64 - - -class TestModel(tf.test.TestCase): - def setUp(self): - gen_data() - - def tearDown(self): - del_data() - - def test_model(self): - jfile = "water_multi.json" - jdata = j_loader(jfile) - - systems = j_must_have(jdata, "systems") - set_pfx = "set" - batch_size = j_must_have(jdata, "batch_size") - test_size = j_must_have(jdata, "numb_test") - batch_size = 1 - test_size = 1 - stop_batch = j_must_have(jdata, "stop_batch") - rcut = j_must_have(jdata["model"]["descriptor"], "rcut") - - data = DataSystem(systems, set_pfx, batch_size, test_size, rcut, run_opt=None) - - test_data = data.get_test() - numb_test = 1 - - jdata["model"]["descriptor"].pop("type", None) - jdata["model"]["descriptor"]["multi_task"] = True - descrpt = DescrptSeA(**jdata["model"]["descriptor"], uniform_seed=True) - fitting_dict = {} - fitting_type_dict = {} - for fitting_key in jdata["model"]["fitting_net_dict"]: - item_fitting_param = jdata["model"]["fitting_net_dict"][fitting_key] - item_fitting_type = item_fitting_param.get("type", "ener") - fitting_type_dict[fitting_key] = item_fitting_type - item_fitting_param.pop("type", None) - item_fitting_param.pop("fit_diag", None) - item_fitting_param["descrpt"] = descrpt - item_fitting_param["embedding_width"] = descrpt.get_dim_rot_mat_1() - item_fitting_param["ntypes"] = descrpt.get_ntypes() - item_fitting_param["dim_descrpt"] = descrpt.get_dim_out() - if item_fitting_type == "ener": - fitting_dict[fitting_key] = EnerFitting( - **item_fitting_param, uniform_seed=True - ) - elif item_fitting_type == "dipole": - fitting_dict[fitting_key] = DipoleFittingSeA( - **item_fitting_param, uniform_seed=True - ) - else: - RuntimeError("Test should not be here!") - model = MultiModel(descrpt, fitting_dict, fitting_type_dict) - - input_data = { - "coord": [test_data["coord"]], - "box": [test_data["box"]], - "type": [test_data["type"]], - "natoms_vec": [test_data["natoms_vec"]], - "default_mesh": [test_data["default_mesh"]], - } - for fitting_key in jdata["model"]["fitting_net_dict"]: - model._compute_input_stat(input_data, fitting_key=fitting_key) - model.descrpt.merge_input_stats(model.descrpt.stat_dict) - model.descrpt.bias_atom_e = data.compute_energy_shift() - - t_prop_c = tf.placeholder(tf.float32, [5], name="t_prop_c") - t_energy = tf.placeholder(GLOBAL_ENER_FLOAT_PRECISION, [None], name="t_energy") - t_force = tf.placeholder(GLOBAL_TF_FLOAT_PRECISION, [None], name="t_force") - t_virial = tf.placeholder(GLOBAL_TF_FLOAT_PRECISION, [None], name="t_virial") - t_atom_ener = tf.placeholder( - GLOBAL_TF_FLOAT_PRECISION, [None], name="t_atom_ener" - ) - t_coord = tf.placeholder(GLOBAL_TF_FLOAT_PRECISION, [None], name="i_coord") - t_type = tf.placeholder(tf.int32, [None], name="i_type") - t_natoms = tf.placeholder(tf.int32, [model.ntypes + 2], name="i_natoms") - t_box = tf.placeholder(GLOBAL_TF_FLOAT_PRECISION, [None, 9], name="i_box") - t_mesh = tf.placeholder(tf.int32, [None], name="i_mesh") - is_training = tf.placeholder(tf.bool) - t_fparam = None - - model_pred = model.build( - t_coord, - t_type, - t_natoms, - t_box, - t_mesh, - t_fparam, - suffix="multi", - reuse=False, - ) - e_energy = model_pred["water_ener"]["energy"] - e_force = model_pred["water_ener"]["force"] - e_virial = model_pred["water_ener"]["virial"] - e_atom_ener = model_pred["water_ener"]["atom_ener"] - - d_dipole = model_pred["water_dipole"]["dipole"] - d_gdipole = model_pred["water_dipole"]["global_dipole"] - d_force = model_pred["water_dipole"]["force"] - d_virial = model_pred["water_dipole"]["virial"] - d_atom_virial = model_pred["water_dipole"]["atom_virial"] - - feed_dict_test = { - t_prop_c: test_data["prop_c"], - t_energy: test_data["energy"][:numb_test], - t_force: np.reshape(test_data["force"][:numb_test, :], [-1]), - t_virial: np.reshape(test_data["virial"][:numb_test, :], [-1]), - t_atom_ener: np.reshape(test_data["atom_ener"][:numb_test, :], [-1]), - t_coord: np.reshape(test_data["coord"][:numb_test, :], [-1]), - t_box: test_data["box"][:numb_test, :], - t_type: np.reshape(test_data["type"][:numb_test, :], [-1]), - t_natoms: test_data["natoms_vec"], - t_mesh: test_data["default_mesh"], - is_training: False, - } - sess = self.cached_session().__enter__() - - # test water energy - sess.run(tf.global_variables_initializer()) - [e, f, v] = sess.run([e_energy, e_force, e_virial], feed_dict=feed_dict_test) - e = e.reshape([-1]) - f = f.reshape([-1]) - v = v.reshape([-1]) - refe = [6.135449167779321300e01] - reff = [ - 7.799691562262310585e-02, - 9.423098804815030483e-02, - 3.790560997388224204e-03, - 1.432522403799846578e-01, - 1.148392791403983204e-01, - -1.321871172563671148e-02, - -7.318966526325138000e-02, - 6.516069212737778116e-02, - 5.406418483320515412e-04, - 5.870713761026503247e-02, - -1.605402669549013672e-01, - -5.089516979826595386e-03, - -2.554593467731766654e-01, - 3.092063507347833987e-02, - 1.510355029451411479e-02, - 4.869271842355533952e-02, - -1.446113274345035005e-01, - -1.126524434771078789e-03, - ] - refv = [ - -6.076776685178300053e-01, - 1.103174323630009418e-01, - 1.984250991380156690e-02, - 1.103174323630009557e-01, - -3.319759402259439551e-01, - -6.007404107650986258e-03, - 1.984250991380157036e-02, - -6.007404107650981921e-03, - -1.200076017439753642e-03, - ] - refe = np.reshape(refe, [-1]) - reff = np.reshape(reff, [-1]) - refv = np.reshape(refv, [-1]) - - places = 10 - np.testing.assert_almost_equal(e, refe, places) - np.testing.assert_almost_equal(f, reff, places) - np.testing.assert_almost_equal(v, refv, places) - - # test water dipole - [p, gp] = sess.run([d_dipole, d_gdipole], feed_dict=feed_dict_test) - p = p.reshape([-1]) - refp = [ - 1.616802262298876514e01, - 9.809535439521079425e00, - 3.572312180768947854e-01, - 1.336308874095981203e00, - 1.057908563208963848e01, - -5.999602350098874881e-01, - ] - places = 10 - np.testing.assert_almost_equal(p, refp, places) - gp = gp.reshape([-1]) - refgp = np.array(refp).reshape(-1, 3).sum(0) - places = 9 - np.testing.assert_almost_equal(gp, refgp, places) - - # test water dipole : make sure only one frame is used - feed_dict_single = { - t_prop_c: test_data["prop_c"], - t_coord: np.reshape(test_data["coord"][:1, :], [-1]), - t_box: test_data["box"][:1, :], - t_type: np.reshape(test_data["type"][:1, :], [-1]), - t_natoms: test_data["natoms_vec"], - t_mesh: test_data["default_mesh"], - is_training: False, - } - - [pf, pv, pav] = sess.run( - [d_force, d_virial, d_atom_virial], feed_dict=feed_dict_single - ) - pf, pv = pf.reshape(-1), pv.reshape(-1) - spv = pav.reshape(1, 3, -1, 9).sum(2).reshape(-1) - - base_dict = feed_dict_single.copy() - coord0 = base_dict.pop(t_coord) - box0 = base_dict.pop(t_box) - - fdf = -finite_difference( - lambda coord: sess.run( - d_gdipole, feed_dict={**base_dict, t_coord: coord, t_box: box0} - ).reshape(-1), - test_data["coord"][:numb_test, :].reshape([-1]), - ).reshape(-1) - fdv = -( - finite_difference( - lambda box: sess.run( - d_gdipole, - feed_dict={ - **base_dict, - t_coord: strerch_box(coord0, box0, box), - t_box: box, - }, - ).reshape(-1), - test_data["box"][:numb_test, :], - ) - .reshape([-1, 3, 3]) - .transpose(0, 2, 1) - @ box0.reshape(3, 3) - ).reshape(-1) - - delta = 1e-5 - np.testing.assert_allclose(pf, fdf, delta) - np.testing.assert_allclose(pv, fdv, delta) - # make sure atomic virial sum to virial - places = 10 - np.testing.assert_almost_equal(pv, spv, places) diff --git a/source/tests/tf/test_nvnmd_entrypoints.py b/source/tests/tf/test_nvnmd_entrypoints.py index bf4b2288c0..17ad62b4bc 100644 --- a/source/tests/tf/test_nvnmd_entrypoints.py +++ b/source/tests/tf/test_nvnmd_entrypoints.py @@ -394,6 +394,9 @@ def test_mapt_cnn_v0(self): @pytest.mark.run(order=1) def test_model_qnn_v0(self): + # without calling test_mapt_cnn_v0, this test will fail when running individually + self.test_mapt_cnn_v0() + tf.reset_default_graph() # open NVNMD jdata_cf = jdata_deepmd_input_v0["nvnmd"] @@ -703,6 +706,9 @@ def test_mapt_cnn_v1(self): @pytest.mark.run(order=1) def test_model_qnn_v1(self): + # without calling test_mapt_cnn_v1, this test will fail when running individually + self.test_mapt_cnn_v1() + tf.reset_default_graph() # open NVNMD jdata_cf = jdata_deepmd_input_v1["nvnmd"] diff --git a/source/tests/tf/water_layer_name.json b/source/tests/tf/water_layer_name.json deleted file mode 100644 index 06b9b981ec..0000000000 --- a/source/tests/tf/water_layer_name.json +++ /dev/null @@ -1,105 +0,0 @@ -{ - "_comment1": "layer_name", - "model": { - "descriptor": { - "type": "se_a", - "sel": [ - 46, - 92 - ], - "rcut_smth": 5.80, - "rcut": 6.00, - "neuron": [ - 8, - 16, - 32 - ], - "resnet_dt": false, - "axis_neuron": 16, - "seed": 1 - }, - "fitting_net_dict": { - "water_ener": { - "type": "ener", - "neuron": [ - 32, - 32, - 32 - ], - "resnet_dt": true, - "layer_name": [ - "layer0", - "layer1", - "layer2", - "final_layer" - ], - "seed": 1 - }, - "water_ener2": { - "type": "ener", - "neuron": [ - 32, - 32, - 32 - ], - "resnet_dt": true, - "layer_name": [ - "layer0", - "layer1", - "layer2", - "final_layer" - ], - "seed": 2 - } - } - }, - "learning_rate": { - "type": "exp", - "start_lr": 0.001, - "decay_steps": 5000, - "decay_rate": 0.95, - "_comment2": "that's all" - }, - - "loss_dict": { - "water_ener": { - "type": "ener", - "start_pref_e": 0.02, - "limit_pref_e": 1, - "start_pref_f": 1000, - "limit_pref_f": 1, - "start_pref_v": 0, - "limit_pref_v": 0 - }, - "water_ener2": { - "type": "ener", - "start_pref_e": 0.02, - "limit_pref_e": 1, - "start_pref_f": 1000, - "limit_pref_f": 1, - "start_pref_v": 0, - "limit_pref_v": 0 - } - }, - - "_comment3": " traing controls", - "systems": [ - "system" - ], - "stop_batch": 1000000, - "batch_size": 1, - "seed": 1, - - "disp_file": "lcurve.out", - "disp_freq": 100, - "numb_test": 1, - "save_freq": 1000, - "save_ckpt": "model.ckpt", - "load_ckpt": "model.ckpt", - "disp_training": true, - "time_training": true, - "profiling": false, - "profiling_file": "timeline.json", - - "_comment4": "that's all" -} diff --git a/source/tests/tf/water_multi.json b/source/tests/tf/water_multi.json deleted file mode 100644 index 44522ed362..0000000000 --- a/source/tests/tf/water_multi.json +++ /dev/null @@ -1,103 +0,0 @@ -{ - "_comment1": " model parameters", - "model": { - "descriptor": { - "type": "se_a", - "sel": [ - 46, - 92 - ], - "rcut_smth": 5.80, - "rcut": 6.00, - "neuron": [ - 25, - 50, - 100 - ], - "resnet_dt": false, - "axis_neuron": 16, - "seed": 1 - }, - "fitting_net_dict": { - "water_ener": { - "type": "ener", - "neuron": [ - 240, - 240, - 240 - ], - "resnet_dt": true, - "seed": 1 - }, - "water_dipole": { - "type": "dipole", - "sel_type": [ - 0 - ], - "fit_diag": false, - "neuron": [ - 100, - 100, - 100 - ], - "resnet_dt": true, - "seed": 1 - } - } - }, - "learning_rate_dict": - { - "water_ener": { - "type": "exp", - "start_lr": 0.001, - "decay_steps": 5000, - "decay_rate": 0.95, - "_comment2": "that's all" - }, - "water_dipole": { - "type": "exp", - "start_lr": 0.001, - "decay_steps": 5000, - "decay_rate": 0.95, - "_comment3": "that's all" - } - }, - - "loss_dict": { - "water_ener": { - "type": "ener", - "start_pref_e": 0.02, - "limit_pref_e": 1, - "start_pref_f": 1000, - "limit_pref_f": 1, - "start_pref_v": 0, - "limit_pref_v": 0 - }, - "water_dipole": { - "type": "tensor", - "pref": 1.0, - "pref_atomic": 1.0 - } - }, - - "_comment4": " traing controls", - "systems": [ - "system" - ], - "stop_batch": 1000000, - "batch_size": 1, - "seed": 1, - - "disp_file": "lcurve.out", - "disp_freq": 100, - "numb_test": 1, - "save_freq": 1000, - "save_ckpt": "model.ckpt", - "load_ckpt": "model.ckpt", - "disp_training": true, - "time_training": true, - "profiling": false, - "profiling_file": "timeline.json", - - "_comment5": "that's all" -}