Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor update_sel #2901

Merged
merged 4 commits into from
Oct 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 28 additions & 8 deletions deepmd/descriptor/descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,16 +66,20 @@ class SomeDescript(Descriptor):
"""
return Descriptor.__plugins.register(key)

@classmethod
def get_class_by_input(cls, input: dict):
try:
descrpt_type = input["type"]
except KeyError:
raise KeyError("the type of descriptor should be set by `type`")
if descrpt_type in Descriptor.__plugins.plugins:
return Descriptor.__plugins.plugins[descrpt_type]
else:
raise RuntimeError("Unknown descriptor type: " + descrpt_type)

def __new__(cls, *args, **kwargs):
if cls is Descriptor:
try:
descrpt_type = kwargs["type"]
except KeyError:
raise KeyError("the type of descriptor should be set by `type`")
if descrpt_type in Descriptor.__plugins.plugins:
cls = Descriptor.__plugins.plugins[descrpt_type]
else:
raise RuntimeError("Unknown descriptor type: " + descrpt_type)
cls = cls.get_class_by_input(kwargs)
return super().__new__(cls)

@abstractmethod
Expand Down Expand Up @@ -489,3 +493,19 @@ def build_type_exclude_mask(
def explicit_ntypes(self) -> bool:
"""Explicit ntypes with type embedding."""
return False

@classmethod
@abstractmethod
def update_sel(cls, global_jdata: dict, local_jdata: dict):
"""Update the selection and perform neighbor statistics.

Parameters
----------
global_jdata : dict
The global data, containing the training section
local_jdata : dict
The local data refer to the current class
"""
# call subprocess
cls = cls.get_class_by_input(local_jdata)
return cls.update_sel(global_jdata, local_jdata)
18 changes: 18 additions & 0 deletions deepmd/descriptor/hybrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,3 +416,21 @@ def pass_tensors_from_frz_model(
def explicit_ntypes(self) -> bool:
"""Explicit ntypes with type embedding."""
return any(ii.explicit_ntypes for ii in self.descrpt_list)

@classmethod
def update_sel(cls, global_jdata: dict, local_jdata: dict):
"""Update the selection and perform neighbor statistics.

Parameters
----------
global_jdata : dict
The global data, containing the training section
local_jdata : dict
The local data refer to the current class
"""
local_jdata_cpy = local_jdata.copy()
local_jdata_cpy["list"] = [
Descriptor.update_sel(global_jdata, sub_jdata)
for sub_jdata in local_jdata["list"]
]
return local_jdata_cpy
13 changes: 13 additions & 0 deletions deepmd/descriptor/loc_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,3 +430,16 @@ def init_variables(
self.dstd = get_tensor_by_name_from_graph(
graph, "descrpt_attr%s/t_std" % suffix
)

@classmethod
def update_sel(cls, global_jdata: dict, local_jdata: dict):
"""Update the selection and perform neighbor statistics.

Parameters
----------
global_jdata : dict
The global data, containing the training section
local_jdata : dict
The local data refer to the current class
"""
return local_jdata
19 changes: 19 additions & 0 deletions deepmd/descriptor/se.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,3 +141,22 @@ def init_variables(
def precision(self) -> tf.DType:
"""Precision of filter network."""
return self.filter_precision

@classmethod
def update_sel(cls, global_jdata: dict, local_jdata: dict):
"""Update the selection and perform neighbor statistics.

Parameters
----------
global_jdata : dict
The global data, containing the training section
local_jdata : dict
The local data refer to the current class
"""
from deepmd.entrypoints.train import (
update_one_sel,
)

# default behavior is to update sel which is a list
local_jdata_cpy = local_jdata.copy()
return update_one_sel(global_jdata, local_jdata_cpy, False)
5 changes: 4 additions & 1 deletion deepmd/descriptor/se_a_ef.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,16 @@
from .descriptor import (
Descriptor,
)
from .se import (
DescrptSe,
)
from .se_a import (
DescrptSeA,
)


@Descriptor.register("se_a_ef")
class DescrptSeAEf(Descriptor):
class DescrptSeAEf(DescrptSe):
r"""Smooth edition descriptor with Ef.

Parameters
Expand Down
18 changes: 18 additions & 0 deletions deepmd/descriptor/se_atten.py
Original file line number Diff line number Diff line change
Expand Up @@ -1404,3 +1404,21 @@ def build_type_exclude_mask(
def explicit_ntypes(self) -> bool:
"""Explicit ntypes with type embedding."""
return True

@classmethod
def update_sel(cls, global_jdata: dict, local_jdata: dict):
"""Update the selection and perform neighbor statistics.

Parameters
----------
global_jdata : dict
The global data, containing the training section
local_jdata : dict
The local data refer to the current class
"""
from deepmd.entrypoints.train import (
update_one_sel,
)

local_jdata_cpy = local_jdata.copy()
return update_one_sel(global_jdata, local_jdata_cpy, True)
41 changes: 13 additions & 28 deletions deepmd/entrypoints/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@
from deepmd.infer.data_modifier import (
DipoleChargeModifier,
)
from deepmd.model.model import (
Model,
)
from deepmd.train.run_options import (
BUILD,
CITATION,
Expand Down Expand Up @@ -374,7 +377,10 @@ def get_type_map(jdata):


def get_nbor_stat(jdata, rcut, one_type: bool = False):
max_rcut = get_rcut(jdata)
# it seems that DeepmdDataSystem does not need rcut
# it's not clear why there is an argument...
# max_rcut = get_rcut(jdata)
max_rcut = rcut
type_map = get_type_map(jdata)

if type_map and len(type_map) == 0:
Expand Down Expand Up @@ -472,18 +478,12 @@ def wrap_up_4(xx):
return 4 * ((int(xx) + 3) // 4)


def update_one_sel(jdata, descriptor):
if descriptor["type"] == "loc_frame":
return descriptor
def update_one_sel(jdata, descriptor, one_type: bool = False):
rcut = descriptor["rcut"]
tmp_sel = get_sel(
jdata,
rcut,
one_type=descriptor["type"]
in (
"se_atten",
"se_atten_v2",
),
one_type=one_type,
)
sel = descriptor["sel"]
if isinstance(sel, int):
Expand All @@ -503,10 +503,7 @@ def update_one_sel(jdata, descriptor):
"not less than %d, but you set it to %d. The accuracy"
" of your model may get worse." % (ii, tt, dd)
)
if descriptor["type"] in (
"se_atten",
"se_atten_v2",
):
if one_type:
descriptor["sel"] = sel = sum(sel)
return descriptor

Expand All @@ -515,18 +512,6 @@ def update_sel(jdata):
log.info(
"Calculate neighbor statistics... (add --skip-neighbor-stat to skip this step)"
)
if jdata["model"].get("type") == "pairwise_dprc":
# do not update sel; only find min distance
rcut = get_rcut(jdata)
get_min_nbor_dist(jdata, rcut)
return jdata
elif jdata["model"].get("type") in ("linear_ener", "frozen"):
return jdata
descrpt_data = jdata["model"]["descriptor"]
if descrpt_data["type"] == "hybrid":
for ii in range(len(descrpt_data["list"])):
descrpt_data["list"][ii] = update_one_sel(jdata, descrpt_data["list"][ii])
else:
descrpt_data = update_one_sel(jdata, descrpt_data)
jdata["model"]["descriptor"] = descrpt_data
return jdata
jdata_cpy = jdata.copy()
jdata_cpy["model"] = Model.update_sel(jdata, jdata["model"])
return jdata_cpy
14 changes: 14 additions & 0 deletions deepmd/model/frozen.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,3 +193,17 @@ def enable_compression(self, suffix: str = "") -> None:
def get_type_map(self) -> list:
"""Get the type map."""
return self.model.get_type_map()

@classmethod
def update_sel(cls, global_jdata: dict, local_jdata: dict):
"""Update the selection and perform neighbor statistics.

Parameters
----------
global_jdata : dict
The global data, containing the training section
local_jdata : dict
The local data refer to the current class
"""
# we don't know how to compress it, so no neighbor statistics here
return local_jdata
18 changes: 18 additions & 0 deletions deepmd/model/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,24 @@ def get_type_map(self) -> list:
"""Get the type map."""
return self.models[0].get_type_map()

@classmethod
def update_sel(cls, global_jdata: dict, local_jdata: dict):
"""Update the selection and perform neighbor statistics.

Parameters
----------
global_jdata : dict
The global data, containing the training section
local_jdata : dict
The local data refer to the current class
"""
local_jdata_cpy = local_jdata.copy()
local_jdata_cpy["models"] = [
Model.update_sel(global_jdata, sub_jdata)
for sub_jdata in local_jdata["models"]
]
return local_jdata_cpy


class LinearEnergyModel(LinearModel):
"""Linear energy model make linear combinations of several existing energy models."""
Expand Down
106 changes: 79 additions & 27 deletions deepmd/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,36 +78,47 @@ class Model(ABC):
Compression information for internal use
"""

@classmethod
def get_class_by_input(cls, input: dict):
"""Get the class by input data.

Parameters
----------
input : dict
The input data
"""
# infer model type by fitting_type
from deepmd.model.frozen import (
FrozenModel,
)
from deepmd.model.linear import (
LinearEnergyModel,
)
from deepmd.model.multi import (
MultiModel,
)
from deepmd.model.pairwise_dprc import (
PairwiseDPRc,
)

model_type = input.get("type", "standard")
if model_type == "standard":
return StandardModel
elif model_type == "multi":
return MultiModel
elif model_type == "pairwise_dprc":
return PairwiseDPRc
elif model_type == "frozen":
return FrozenModel
elif model_type == "linear_ener":
return LinearEnergyModel
else:
raise ValueError(f"unknown model type: {model_type}")

def __new__(cls, *args, **kwargs):
if cls is Model:
# init model
# infer model type by fitting_type
from deepmd.model.frozen import (
FrozenModel,
)
from deepmd.model.linear import (
LinearEnergyModel,
)
from deepmd.model.multi import (
MultiModel,
)
from deepmd.model.pairwise_dprc import (
PairwiseDPRc,
)

model_type = kwargs.get("type", "standard")
if model_type == "standard":
cls = StandardModel
elif model_type == "multi":
cls = MultiModel
elif model_type == "pairwise_dprc":
cls = PairwiseDPRc
elif model_type == "frozen":
cls = FrozenModel
elif model_type == "linear_ener":
cls = LinearEnergyModel
else:
raise ValueError(f"unknown model type: {model_type}")
cls = cls.get_class_by_input(kwargs)
return cls.__new__(cls, *args, **kwargs)
return super().__new__(cls)

Expand Down Expand Up @@ -471,6 +482,30 @@ def get_feed_dict(
feed_dict["t_aparam:0"] = kwargs["aparam"]
return feed_dict

@classmethod
@abstractmethod
def update_sel(cls, global_jdata: dict, local_jdata: dict) -> dict:
"""Update the selection and perform neighbor statistics.

Notes
-----
Do not modify the input data without copying it.

Parameters
----------
global_jdata : dict
The global data, containing the training section
local_jdata : dict
The local data refer to the current class

Returns
-------
dict
The updated local data
"""
cls = cls.get_class_by_input(local_jdata)
return cls.update_sel(global_jdata, local_jdata)


class StandardModel(Model):
"""Standard model, which must contain a descriptor and a fitting.
Expand Down Expand Up @@ -613,3 +648,20 @@ def get_rcut(self) -> float:
def get_ntypes(self) -> int:
"""Get the number of types."""
return self.ntypes

@classmethod
def update_sel(cls, global_jdata: dict, local_jdata: dict):
"""Update the selection and perform neighbor statistics.

Parameters
----------
global_jdata : dict
The global data, containing the training section
local_jdata : dict
The local data refer to the current class
"""
local_jdata_cpy = local_jdata.copy()
local_jdata_cpy["descriptor"] = Descriptor.update_sel(
global_jdata, local_jdata["descriptor"]
)
return local_jdata_cpy
Loading