Skip to content

Commit

Permalink
refactor update_sel (#2901)
Browse files Browse the repository at this point in the history
The current `update_sel` has too many complex conditionals. This PR
rewrites them in the `Model` and `Descriptor` classes, which makes it
easier to maintain the existing classes and implement new classes.

---------

Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz authored Oct 5, 2023
1 parent 6f4fc02 commit 7bf1619
Show file tree
Hide file tree
Showing 13 changed files with 305 additions and 64 deletions.
36 changes: 28 additions & 8 deletions deepmd/descriptor/descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,16 +66,20 @@ class SomeDescript(Descriptor):
"""
return Descriptor.__plugins.register(key)

@classmethod
def get_class_by_input(cls, input: dict):
try:
descrpt_type = input["type"]
except KeyError:
raise KeyError("the type of descriptor should be set by `type`")
if descrpt_type in Descriptor.__plugins.plugins:
return Descriptor.__plugins.plugins[descrpt_type]
else:
raise RuntimeError("Unknown descriptor type: " + descrpt_type)

def __new__(cls, *args, **kwargs):
if cls is Descriptor:
try:
descrpt_type = kwargs["type"]
except KeyError:
raise KeyError("the type of descriptor should be set by `type`")
if descrpt_type in Descriptor.__plugins.plugins:
cls = Descriptor.__plugins.plugins[descrpt_type]
else:
raise RuntimeError("Unknown descriptor type: " + descrpt_type)
cls = cls.get_class_by_input(kwargs)
return super().__new__(cls)

@abstractmethod
Expand Down Expand Up @@ -489,3 +493,19 @@ def build_type_exclude_mask(
def explicit_ntypes(self) -> bool:
"""Explicit ntypes with type embedding."""
return False

@classmethod
@abstractmethod
def update_sel(cls, global_jdata: dict, local_jdata: dict):
"""Update the selection and perform neighbor statistics.
Parameters
----------
global_jdata : dict
The global data, containing the training section
local_jdata : dict
The local data refer to the current class
"""
# call subprocess
cls = cls.get_class_by_input(local_jdata)
return cls.update_sel(global_jdata, local_jdata)
18 changes: 18 additions & 0 deletions deepmd/descriptor/hybrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,3 +416,21 @@ def pass_tensors_from_frz_model(
def explicit_ntypes(self) -> bool:
"""Explicit ntypes with type embedding."""
return any(ii.explicit_ntypes for ii in self.descrpt_list)

@classmethod
def update_sel(cls, global_jdata: dict, local_jdata: dict):
"""Update the selection and perform neighbor statistics.
Parameters
----------
global_jdata : dict
The global data, containing the training section
local_jdata : dict
The local data refer to the current class
"""
local_jdata_cpy = local_jdata.copy()
local_jdata_cpy["list"] = [
Descriptor.update_sel(global_jdata, sub_jdata)
for sub_jdata in local_jdata["list"]
]
return local_jdata_cpy
13 changes: 13 additions & 0 deletions deepmd/descriptor/loc_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,3 +430,16 @@ def init_variables(
self.dstd = get_tensor_by_name_from_graph(
graph, "descrpt_attr%s/t_std" % suffix
)

@classmethod
def update_sel(cls, global_jdata: dict, local_jdata: dict):
"""Update the selection and perform neighbor statistics.
Parameters
----------
global_jdata : dict
The global data, containing the training section
local_jdata : dict
The local data refer to the current class
"""
return local_jdata
19 changes: 19 additions & 0 deletions deepmd/descriptor/se.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,3 +141,22 @@ def init_variables(
def precision(self) -> tf.DType:
"""Precision of filter network."""
return self.filter_precision

@classmethod
def update_sel(cls, global_jdata: dict, local_jdata: dict):
"""Update the selection and perform neighbor statistics.
Parameters
----------
global_jdata : dict
The global data, containing the training section
local_jdata : dict
The local data refer to the current class
"""
from deepmd.entrypoints.train import (
update_one_sel,
)

# default behavior is to update sel which is a list
local_jdata_cpy = local_jdata.copy()
return update_one_sel(global_jdata, local_jdata_cpy, False)
5 changes: 4 additions & 1 deletion deepmd/descriptor/se_a_ef.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,16 @@
from .descriptor import (
Descriptor,
)
from .se import (
DescrptSe,
)
from .se_a import (
DescrptSeA,
)


@Descriptor.register("se_a_ef")
class DescrptSeAEf(Descriptor):
class DescrptSeAEf(DescrptSe):
r"""Smooth edition descriptor with Ef.
Parameters
Expand Down
18 changes: 18 additions & 0 deletions deepmd/descriptor/se_atten.py
Original file line number Diff line number Diff line change
Expand Up @@ -1404,3 +1404,21 @@ def build_type_exclude_mask(
def explicit_ntypes(self) -> bool:
"""Explicit ntypes with type embedding."""
return True

@classmethod
def update_sel(cls, global_jdata: dict, local_jdata: dict):
"""Update the selection and perform neighbor statistics.
Parameters
----------
global_jdata : dict
The global data, containing the training section
local_jdata : dict
The local data refer to the current class
"""
from deepmd.entrypoints.train import (
update_one_sel,
)

local_jdata_cpy = local_jdata.copy()
return update_one_sel(global_jdata, local_jdata_cpy, True)
41 changes: 13 additions & 28 deletions deepmd/entrypoints/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@
from deepmd.infer.data_modifier import (
DipoleChargeModifier,
)
from deepmd.model.model import (
Model,
)
from deepmd.train.run_options import (
BUILD,
CITATION,
Expand Down Expand Up @@ -374,7 +377,10 @@ def get_type_map(jdata):


def get_nbor_stat(jdata, rcut, one_type: bool = False):
max_rcut = get_rcut(jdata)
# it seems that DeepmdDataSystem does not need rcut
# it's not clear why there is an argument...
# max_rcut = get_rcut(jdata)
max_rcut = rcut
type_map = get_type_map(jdata)

if type_map and len(type_map) == 0:
Expand Down Expand Up @@ -472,18 +478,12 @@ def wrap_up_4(xx):
return 4 * ((int(xx) + 3) // 4)


def update_one_sel(jdata, descriptor):
if descriptor["type"] == "loc_frame":
return descriptor
def update_one_sel(jdata, descriptor, one_type: bool = False):
rcut = descriptor["rcut"]
tmp_sel = get_sel(
jdata,
rcut,
one_type=descriptor["type"]
in (
"se_atten",
"se_atten_v2",
),
one_type=one_type,
)
sel = descriptor["sel"]
if isinstance(sel, int):
Expand All @@ -503,10 +503,7 @@ def update_one_sel(jdata, descriptor):
"not less than %d, but you set it to %d. The accuracy"
" of your model may get worse." % (ii, tt, dd)
)
if descriptor["type"] in (
"se_atten",
"se_atten_v2",
):
if one_type:
descriptor["sel"] = sel = sum(sel)
return descriptor

Expand All @@ -515,18 +512,6 @@ def update_sel(jdata):
log.info(
"Calculate neighbor statistics... (add --skip-neighbor-stat to skip this step)"
)
if jdata["model"].get("type") == "pairwise_dprc":
# do not update sel; only find min distance
rcut = get_rcut(jdata)
get_min_nbor_dist(jdata, rcut)
return jdata
elif jdata["model"].get("type") in ("linear_ener", "frozen"):
return jdata
descrpt_data = jdata["model"]["descriptor"]
if descrpt_data["type"] == "hybrid":
for ii in range(len(descrpt_data["list"])):
descrpt_data["list"][ii] = update_one_sel(jdata, descrpt_data["list"][ii])
else:
descrpt_data = update_one_sel(jdata, descrpt_data)
jdata["model"]["descriptor"] = descrpt_data
return jdata
jdata_cpy = jdata.copy()
jdata_cpy["model"] = Model.update_sel(jdata, jdata["model"])
return jdata_cpy
14 changes: 14 additions & 0 deletions deepmd/model/frozen.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,3 +193,17 @@ def enable_compression(self, suffix: str = "") -> None:
def get_type_map(self) -> list:
"""Get the type map."""
return self.model.get_type_map()

@classmethod
def update_sel(cls, global_jdata: dict, local_jdata: dict):
"""Update the selection and perform neighbor statistics.
Parameters
----------
global_jdata : dict
The global data, containing the training section
local_jdata : dict
The local data refer to the current class
"""
# we don't know how to compress it, so no neighbor statistics here
return local_jdata
18 changes: 18 additions & 0 deletions deepmd/model/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,24 @@ def get_type_map(self) -> list:
"""Get the type map."""
return self.models[0].get_type_map()

@classmethod
def update_sel(cls, global_jdata: dict, local_jdata: dict):
"""Update the selection and perform neighbor statistics.
Parameters
----------
global_jdata : dict
The global data, containing the training section
local_jdata : dict
The local data refer to the current class
"""
local_jdata_cpy = local_jdata.copy()
local_jdata_cpy["models"] = [
Model.update_sel(global_jdata, sub_jdata)
for sub_jdata in local_jdata["models"]
]
return local_jdata_cpy


class LinearEnergyModel(LinearModel):
"""Linear energy model make linear combinations of several existing energy models."""
Expand Down
106 changes: 79 additions & 27 deletions deepmd/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,36 +78,47 @@ class Model(ABC):
Compression information for internal use
"""

@classmethod
def get_class_by_input(cls, input: dict):
"""Get the class by input data.
Parameters
----------
input : dict
The input data
"""
# infer model type by fitting_type
from deepmd.model.frozen import (
FrozenModel,
)
from deepmd.model.linear import (
LinearEnergyModel,
)
from deepmd.model.multi import (
MultiModel,
)
from deepmd.model.pairwise_dprc import (
PairwiseDPRc,
)

model_type = input.get("type", "standard")
if model_type == "standard":
return StandardModel
elif model_type == "multi":
return MultiModel
elif model_type == "pairwise_dprc":
return PairwiseDPRc
elif model_type == "frozen":
return FrozenModel
elif model_type == "linear_ener":
return LinearEnergyModel
else:
raise ValueError(f"unknown model type: {model_type}")

def __new__(cls, *args, **kwargs):
if cls is Model:
# init model
# infer model type by fitting_type
from deepmd.model.frozen import (
FrozenModel,
)
from deepmd.model.linear import (
LinearEnergyModel,
)
from deepmd.model.multi import (
MultiModel,
)
from deepmd.model.pairwise_dprc import (
PairwiseDPRc,
)

model_type = kwargs.get("type", "standard")
if model_type == "standard":
cls = StandardModel
elif model_type == "multi":
cls = MultiModel
elif model_type == "pairwise_dprc":
cls = PairwiseDPRc
elif model_type == "frozen":
cls = FrozenModel
elif model_type == "linear_ener":
cls = LinearEnergyModel
else:
raise ValueError(f"unknown model type: {model_type}")
cls = cls.get_class_by_input(kwargs)
return cls.__new__(cls, *args, **kwargs)
return super().__new__(cls)

Expand Down Expand Up @@ -471,6 +482,30 @@ def get_feed_dict(
feed_dict["t_aparam:0"] = kwargs["aparam"]
return feed_dict

@classmethod
@abstractmethod
def update_sel(cls, global_jdata: dict, local_jdata: dict) -> dict:
"""Update the selection and perform neighbor statistics.
Notes
-----
Do not modify the input data without copying it.
Parameters
----------
global_jdata : dict
The global data, containing the training section
local_jdata : dict
The local data refer to the current class
Returns
-------
dict
The updated local data
"""
cls = cls.get_class_by_input(local_jdata)
return cls.update_sel(global_jdata, local_jdata)


class StandardModel(Model):
"""Standard model, which must contain a descriptor and a fitting.
Expand Down Expand Up @@ -613,3 +648,20 @@ def get_rcut(self) -> float:
def get_ntypes(self) -> int:
"""Get the number of types."""
return self.ntypes

@classmethod
def update_sel(cls, global_jdata: dict, local_jdata: dict):
"""Update the selection and perform neighbor statistics.
Parameters
----------
global_jdata : dict
The global data, containing the training section
local_jdata : dict
The local data refer to the current class
"""
local_jdata_cpy = local_jdata.copy()
local_jdata_cpy["descriptor"] = Descriptor.update_sel(
global_jdata, local_jdata["descriptor"]
)
return local_jdata_cpy
Loading

0 comments on commit 7bf1619

Please sign in to comment.