From 7bf16194b63de843a0c02c6456680009c712e907 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Thu, 5 Oct 2023 06:53:06 -0400 Subject: [PATCH] refactor update_sel (#2901) The current `update_sel` has too many complex conditionals. This PR rewrites them in the `Model` and `Descriptor` classes, which makes it easier to maintain the existing classes and implement new classes. --------- Signed-off-by: Jinzhe Zeng --- deepmd/descriptor/descriptor.py | 36 ++++++++--- deepmd/descriptor/hybrid.py | 18 ++++++ deepmd/descriptor/loc_frame.py | 13 ++++ deepmd/descriptor/se.py | 19 ++++++ deepmd/descriptor/se_a_ef.py | 5 +- deepmd/descriptor/se_atten.py | 18 ++++++ deepmd/entrypoints/train.py | 41 ++++-------- deepmd/model/frozen.py | 14 +++++ deepmd/model/linear.py | 18 ++++++ deepmd/model/model.py | 106 ++++++++++++++++++++++++-------- deepmd/model/multi.py | 17 +++++ deepmd/model/pairwise_dprc.py | 20 ++++++ source/tests/test_train.py | 44 +++++++++++++ 13 files changed, 305 insertions(+), 64 deletions(-) diff --git a/deepmd/descriptor/descriptor.py b/deepmd/descriptor/descriptor.py index c885e73145..bd731004cb 100644 --- a/deepmd/descriptor/descriptor.py +++ b/deepmd/descriptor/descriptor.py @@ -66,16 +66,20 @@ class SomeDescript(Descriptor): """ return Descriptor.__plugins.register(key) + @classmethod + def get_class_by_input(cls, input: dict): + try: + descrpt_type = input["type"] + except KeyError: + raise KeyError("the type of descriptor should be set by `type`") + if descrpt_type in Descriptor.__plugins.plugins: + return Descriptor.__plugins.plugins[descrpt_type] + else: + raise RuntimeError("Unknown descriptor type: " + descrpt_type) + def __new__(cls, *args, **kwargs): if cls is Descriptor: - try: - descrpt_type = kwargs["type"] - except KeyError: - raise KeyError("the type of descriptor should be set by `type`") - if descrpt_type in Descriptor.__plugins.plugins: - cls = Descriptor.__plugins.plugins[descrpt_type] - else: - raise RuntimeError("Unknown descriptor type: " + descrpt_type) + cls = cls.get_class_by_input(kwargs) return super().__new__(cls) @abstractmethod @@ -489,3 +493,19 @@ def build_type_exclude_mask( def explicit_ntypes(self) -> bool: """Explicit ntypes with type embedding.""" return False + + @classmethod + @abstractmethod + def update_sel(cls, global_jdata: dict, local_jdata: dict): + """Update the selection and perform neighbor statistics. + + Parameters + ---------- + global_jdata : dict + The global data, containing the training section + local_jdata : dict + The local data refer to the current class + """ + # call subprocess + cls = cls.get_class_by_input(local_jdata) + return cls.update_sel(global_jdata, local_jdata) diff --git a/deepmd/descriptor/hybrid.py b/deepmd/descriptor/hybrid.py index 26736cd653..5ee5ec884b 100644 --- a/deepmd/descriptor/hybrid.py +++ b/deepmd/descriptor/hybrid.py @@ -416,3 +416,21 @@ def pass_tensors_from_frz_model( def explicit_ntypes(self) -> bool: """Explicit ntypes with type embedding.""" return any(ii.explicit_ntypes for ii in self.descrpt_list) + + @classmethod + def update_sel(cls, global_jdata: dict, local_jdata: dict): + """Update the selection and perform neighbor statistics. + + Parameters + ---------- + global_jdata : dict + The global data, containing the training section + local_jdata : dict + The local data refer to the current class + """ + local_jdata_cpy = local_jdata.copy() + local_jdata_cpy["list"] = [ + Descriptor.update_sel(global_jdata, sub_jdata) + for sub_jdata in local_jdata["list"] + ] + return local_jdata_cpy diff --git a/deepmd/descriptor/loc_frame.py b/deepmd/descriptor/loc_frame.py index 409e59f5e7..0765be55f8 100644 --- a/deepmd/descriptor/loc_frame.py +++ b/deepmd/descriptor/loc_frame.py @@ -430,3 +430,16 @@ def init_variables( self.dstd = get_tensor_by_name_from_graph( graph, "descrpt_attr%s/t_std" % suffix ) + + @classmethod + def update_sel(cls, global_jdata: dict, local_jdata: dict): + """Update the selection and perform neighbor statistics. + + Parameters + ---------- + global_jdata : dict + The global data, containing the training section + local_jdata : dict + The local data refer to the current class + """ + return local_jdata diff --git a/deepmd/descriptor/se.py b/deepmd/descriptor/se.py index 3a1ec41ddb..598f6f9ff8 100644 --- a/deepmd/descriptor/se.py +++ b/deepmd/descriptor/se.py @@ -141,3 +141,22 @@ def init_variables( def precision(self) -> tf.DType: """Precision of filter network.""" return self.filter_precision + + @classmethod + def update_sel(cls, global_jdata: dict, local_jdata: dict): + """Update the selection and perform neighbor statistics. + + Parameters + ---------- + global_jdata : dict + The global data, containing the training section + local_jdata : dict + The local data refer to the current class + """ + from deepmd.entrypoints.train import ( + update_one_sel, + ) + + # default behavior is to update sel which is a list + local_jdata_cpy = local_jdata.copy() + return update_one_sel(global_jdata, local_jdata_cpy, False) diff --git a/deepmd/descriptor/se_a_ef.py b/deepmd/descriptor/se_a_ef.py index fb886483f6..32a62b48f3 100644 --- a/deepmd/descriptor/se_a_ef.py +++ b/deepmd/descriptor/se_a_ef.py @@ -24,13 +24,16 @@ from .descriptor import ( Descriptor, ) +from .se import ( + DescrptSe, +) from .se_a import ( DescrptSeA, ) @Descriptor.register("se_a_ef") -class DescrptSeAEf(Descriptor): +class DescrptSeAEf(DescrptSe): r"""Smooth edition descriptor with Ef. Parameters diff --git a/deepmd/descriptor/se_atten.py b/deepmd/descriptor/se_atten.py index b0c65108e5..ce280f7ee8 100644 --- a/deepmd/descriptor/se_atten.py +++ b/deepmd/descriptor/se_atten.py @@ -1404,3 +1404,21 @@ def build_type_exclude_mask( def explicit_ntypes(self) -> bool: """Explicit ntypes with type embedding.""" return True + + @classmethod + def update_sel(cls, global_jdata: dict, local_jdata: dict): + """Update the selection and perform neighbor statistics. + + Parameters + ---------- + global_jdata : dict + The global data, containing the training section + local_jdata : dict + The local data refer to the current class + """ + from deepmd.entrypoints.train import ( + update_one_sel, + ) + + local_jdata_cpy = local_jdata.copy() + return update_one_sel(global_jdata, local_jdata_cpy, True) diff --git a/deepmd/entrypoints/train.py b/deepmd/entrypoints/train.py index 716ff482a3..bd7a2ac7ec 100755 --- a/deepmd/entrypoints/train.py +++ b/deepmd/entrypoints/train.py @@ -27,6 +27,9 @@ from deepmd.infer.data_modifier import ( DipoleChargeModifier, ) +from deepmd.model.model import ( + Model, +) from deepmd.train.run_options import ( BUILD, CITATION, @@ -374,7 +377,10 @@ def get_type_map(jdata): def get_nbor_stat(jdata, rcut, one_type: bool = False): - max_rcut = get_rcut(jdata) + # it seems that DeepmdDataSystem does not need rcut + # it's not clear why there is an argument... + # max_rcut = get_rcut(jdata) + max_rcut = rcut type_map = get_type_map(jdata) if type_map and len(type_map) == 0: @@ -472,18 +478,12 @@ def wrap_up_4(xx): return 4 * ((int(xx) + 3) // 4) -def update_one_sel(jdata, descriptor): - if descriptor["type"] == "loc_frame": - return descriptor +def update_one_sel(jdata, descriptor, one_type: bool = False): rcut = descriptor["rcut"] tmp_sel = get_sel( jdata, rcut, - one_type=descriptor["type"] - in ( - "se_atten", - "se_atten_v2", - ), + one_type=one_type, ) sel = descriptor["sel"] if isinstance(sel, int): @@ -503,10 +503,7 @@ def update_one_sel(jdata, descriptor): "not less than %d, but you set it to %d. The accuracy" " of your model may get worse." % (ii, tt, dd) ) - if descriptor["type"] in ( - "se_atten", - "se_atten_v2", - ): + if one_type: descriptor["sel"] = sel = sum(sel) return descriptor @@ -515,18 +512,6 @@ def update_sel(jdata): log.info( "Calculate neighbor statistics... (add --skip-neighbor-stat to skip this step)" ) - if jdata["model"].get("type") == "pairwise_dprc": - # do not update sel; only find min distance - rcut = get_rcut(jdata) - get_min_nbor_dist(jdata, rcut) - return jdata - elif jdata["model"].get("type") in ("linear_ener", "frozen"): - return jdata - descrpt_data = jdata["model"]["descriptor"] - if descrpt_data["type"] == "hybrid": - for ii in range(len(descrpt_data["list"])): - descrpt_data["list"][ii] = update_one_sel(jdata, descrpt_data["list"][ii]) - else: - descrpt_data = update_one_sel(jdata, descrpt_data) - jdata["model"]["descriptor"] = descrpt_data - return jdata + jdata_cpy = jdata.copy() + jdata_cpy["model"] = Model.update_sel(jdata, jdata["model"]) + return jdata_cpy diff --git a/deepmd/model/frozen.py b/deepmd/model/frozen.py index 972acb9185..38f342ebec 100644 --- a/deepmd/model/frozen.py +++ b/deepmd/model/frozen.py @@ -193,3 +193,17 @@ def enable_compression(self, suffix: str = "") -> None: def get_type_map(self) -> list: """Get the type map.""" return self.model.get_type_map() + + @classmethod + def update_sel(cls, global_jdata: dict, local_jdata: dict): + """Update the selection and perform neighbor statistics. + + Parameters + ---------- + global_jdata : dict + The global data, containing the training section + local_jdata : dict + The local data refer to the current class + """ + # we don't know how to compress it, so no neighbor statistics here + return local_jdata diff --git a/deepmd/model/linear.py b/deepmd/model/linear.py index 799642ce33..7c527fe9dc 100644 --- a/deepmd/model/linear.py +++ b/deepmd/model/linear.py @@ -128,6 +128,24 @@ def get_type_map(self) -> list: """Get the type map.""" return self.models[0].get_type_map() + @classmethod + def update_sel(cls, global_jdata: dict, local_jdata: dict): + """Update the selection and perform neighbor statistics. + + Parameters + ---------- + global_jdata : dict + The global data, containing the training section + local_jdata : dict + The local data refer to the current class + """ + local_jdata_cpy = local_jdata.copy() + local_jdata_cpy["models"] = [ + Model.update_sel(global_jdata, sub_jdata) + for sub_jdata in local_jdata["models"] + ] + return local_jdata_cpy + class LinearEnergyModel(LinearModel): """Linear energy model make linear combinations of several existing energy models.""" diff --git a/deepmd/model/model.py b/deepmd/model/model.py index 9ae5eacf4f..cef2067609 100644 --- a/deepmd/model/model.py +++ b/deepmd/model/model.py @@ -78,36 +78,47 @@ class Model(ABC): Compression information for internal use """ + @classmethod + def get_class_by_input(cls, input: dict): + """Get the class by input data. + + Parameters + ---------- + input : dict + The input data + """ + # infer model type by fitting_type + from deepmd.model.frozen import ( + FrozenModel, + ) + from deepmd.model.linear import ( + LinearEnergyModel, + ) + from deepmd.model.multi import ( + MultiModel, + ) + from deepmd.model.pairwise_dprc import ( + PairwiseDPRc, + ) + + model_type = input.get("type", "standard") + if model_type == "standard": + return StandardModel + elif model_type == "multi": + return MultiModel + elif model_type == "pairwise_dprc": + return PairwiseDPRc + elif model_type == "frozen": + return FrozenModel + elif model_type == "linear_ener": + return LinearEnergyModel + else: + raise ValueError(f"unknown model type: {model_type}") + def __new__(cls, *args, **kwargs): if cls is Model: # init model - # infer model type by fitting_type - from deepmd.model.frozen import ( - FrozenModel, - ) - from deepmd.model.linear import ( - LinearEnergyModel, - ) - from deepmd.model.multi import ( - MultiModel, - ) - from deepmd.model.pairwise_dprc import ( - PairwiseDPRc, - ) - - model_type = kwargs.get("type", "standard") - if model_type == "standard": - cls = StandardModel - elif model_type == "multi": - cls = MultiModel - elif model_type == "pairwise_dprc": - cls = PairwiseDPRc - elif model_type == "frozen": - cls = FrozenModel - elif model_type == "linear_ener": - cls = LinearEnergyModel - else: - raise ValueError(f"unknown model type: {model_type}") + cls = cls.get_class_by_input(kwargs) return cls.__new__(cls, *args, **kwargs) return super().__new__(cls) @@ -471,6 +482,30 @@ def get_feed_dict( feed_dict["t_aparam:0"] = kwargs["aparam"] return feed_dict + @classmethod + @abstractmethod + def update_sel(cls, global_jdata: dict, local_jdata: dict) -> dict: + """Update the selection and perform neighbor statistics. + + Notes + ----- + Do not modify the input data without copying it. + + Parameters + ---------- + global_jdata : dict + The global data, containing the training section + local_jdata : dict + The local data refer to the current class + + Returns + ------- + dict + The updated local data + """ + cls = cls.get_class_by_input(local_jdata) + return cls.update_sel(global_jdata, local_jdata) + class StandardModel(Model): """Standard model, which must contain a descriptor and a fitting. @@ -613,3 +648,20 @@ def get_rcut(self) -> float: def get_ntypes(self) -> int: """Get the number of types.""" return self.ntypes + + @classmethod + def update_sel(cls, global_jdata: dict, local_jdata: dict): + """Update the selection and perform neighbor statistics. + + Parameters + ---------- + global_jdata : dict + The global data, containing the training section + local_jdata : dict + The local data refer to the current class + """ + local_jdata_cpy = local_jdata.copy() + local_jdata_cpy["descriptor"] = Descriptor.update_sel( + global_jdata, local_jdata["descriptor"] + ) + return local_jdata_cpy diff --git a/deepmd/model/multi.py b/deepmd/model/multi.py index b0aa11a109..c224cfdb21 100644 --- a/deepmd/model/multi.py +++ b/deepmd/model/multi.py @@ -645,3 +645,20 @@ def get_loss(self, loss: dict, lr: dict) -> Dict[str, Loss]: loss_param, lr[fitting_key] ) return loss_dict + + @classmethod + def update_sel(cls, global_jdata: dict, local_jdata: dict): + """Update the selection and perform neighbor statistics. + + Parameters + ---------- + global_jdata : dict + The global data, containing the training section + local_jdata : dict + The local data refer to the current class + """ + local_jdata_cpy = local_jdata.copy() + local_jdata_cpy["descriptor"] = Descriptor.update_sel( + global_jdata, local_jdata["descriptor"] + ) + return local_jdata_cpy diff --git a/deepmd/model/pairwise_dprc.py b/deepmd/model/pairwise_dprc.py index 8f46ec239d..80aea92bb1 100644 --- a/deepmd/model/pairwise_dprc.py +++ b/deepmd/model/pairwise_dprc.py @@ -395,6 +395,26 @@ def get_feed_dict( } return feed_dict + @classmethod + def update_sel(cls, global_jdata: dict, local_jdata: dict): + """Update the selection and perform neighbor statistics. + + Parameters + ---------- + global_jdata : dict + The global data, containing the training section + local_jdata : dict + The local data refer to the current class + """ + from deepmd.entrypoints.train import ( + get_min_nbor_dist, + ) + + # do not update sel; only find min distance + # rcut is not important here + get_min_nbor_dist(global_jdata, 6.0) + return local_jdata + def gather_placeholder( params: tf.Tensor, indices: tf.Tensor, placeholder: float = 0.0, **kwargs diff --git a/source/tests/test_train.py b/source/tests/test_train.py index 3d190ba716..145457260f 100644 --- a/source/tests/test_train.py +++ b/source/tests/test_train.py @@ -174,6 +174,50 @@ def test_skip_loc_frame(self): jdata = update_sel(jdata) self.assertEqual(jdata, expected_out) + def test_skip_frozen(self): + jdata = { + "model": { + "type": "frozen", + } + } + expected_out = jdata.copy() + jdata = update_sel(jdata) + self.assertEqual(jdata, expected_out) + + def test_skip_linear_frozen(self): + jdata = { + "model": { + "type": "linear_ener", + "models": [ + {"type": "frozen"}, + {"type": "frozen"}, + {"type": "frozen"}, + {"type": "frozen"}, + ], + } + } + expected_out = jdata.copy() + jdata = update_sel(jdata) + self.assertEqual(jdata, expected_out) + + @patch("deepmd.entrypoints.train.get_min_nbor_dist") + def test_pairwise_dprc(self, sel_mock): + sel_mock.return_value = 0.5 + jdata = { + "model": { + "type": "pairwise_dprc", + "models": [ + {"type": "frozen"}, + {"type": "frozen"}, + {"type": "frozen"}, + {"type": "frozen"}, + ], + } + } + expected_out = jdata.copy() + jdata = update_sel(jdata) + self.assertEqual(jdata, expected_out) + def test_wrap_up_4(self): self.assertEqual(wrap_up_4(12), 3 * 4) self.assertEqual(wrap_up_4(13), 4 * 4)