Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: refactor update_sel and save min_nbor_dist #3829

Merged
merged 14 commits into from
May 31, 2024
29 changes: 25 additions & 4 deletions deepmd/dpmodel/descriptor/dpa1.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
from deepmd.env import (
GLOBAL_NP_FLOAT_PRECISION,
)
from deepmd.utils.data_system import (
DeepmdDataSystem,
)
from deepmd.utils.path import (
DPPath,
)
Expand Down Expand Up @@ -524,18 +527,36 @@
return obj

@classmethod
def update_sel(cls, global_jdata: dict, local_jdata: dict):
def update_sel(
cls,
train_data: DeepmdDataSystem,
type_map: Optional[List[str]],
local_jdata: dict,
) -> Tuple[dict, Optional[float]]:
"""Update the selection and perform neighbor statistics.

Parameters
----------
global_jdata : dict
The global data, containing the training section
train_data : DeepmdDataSystem
data used to do neighbor statictics
type_map : list[str], optional
The name of each type of atoms
local_jdata : dict
The local data refer to the current class

Returns
-------
dict
The updated local data
float
The minimum distance between two atoms
"""
local_jdata_cpy = local_jdata.copy()
return UpdateSel().update_one_sel(global_jdata, local_jdata_cpy, True)
min_nbor_dist, sel = UpdateSel().update_one_sel(

Check warning on line 555 in deepmd/dpmodel/descriptor/dpa1.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/descriptor/dpa1.py#L555

Added line #L555 was not covered by tests
train_data, type_map, local_jdata_cpy["rcut"], local_jdata_cpy["sel"], True
)
local_jdata_cpy["sel"] = sel[0]
return local_jdata_cpy, min_nbor_dist

Check warning on line 559 in deepmd/dpmodel/descriptor/dpa1.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/descriptor/dpa1.py#L558-L559

Added lines #L558 - L559 were not covered by tests


@DescriptorBlock.register("se_atten")
Expand Down
47 changes: 33 additions & 14 deletions deepmd/dpmodel/descriptor/dpa2.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
from deepmd.dpmodel.utils.update_sel import (
UpdateSel,
)
from deepmd.utils.data_system import (
DeepmdDataSystem,
)
from deepmd.utils.path import (
DPPath,
)
Expand Down Expand Up @@ -743,30 +746,46 @@
return obj

@classmethod
def update_sel(cls, global_jdata: dict, local_jdata: dict):
def update_sel(
cls,
train_data: DeepmdDataSystem,
type_map: Optional[List[str]],
local_jdata: dict,
) -> Tuple[dict, Optional[float]]:
"""Update the selection and perform neighbor statistics.

Parameters
----------
global_jdata : dict
The global data, containing the training section
train_data : DeepmdDataSystem
data used to do neighbor statictics
type_map : list[str], optional
The name of each type of atoms
local_jdata : dict
The local data refer to the current class

Returns
-------
dict
The updated local data
float
The minimum distance between two atoms
"""
local_jdata_cpy = local_jdata.copy()
update_sel = UpdateSel()
local_jdata_cpy["repinit"] = update_sel.update_one_sel(
global_jdata,
local_jdata_cpy["repinit"],
min_nbor_dist, repinit_sel = update_sel.update_one_sel(

Check warning on line 775 in deepmd/dpmodel/descriptor/dpa2.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/descriptor/dpa2.py#L775

Added line #L775 was not covered by tests
train_data,
type_map,
local_jdata_cpy["repinit"]["rcut"],
local_jdata_cpy["repinit"]["nsel"],
True,
rcut_key="rcut",
sel_key="nsel",
)
local_jdata_cpy["repformer"] = update_sel.update_one_sel(
global_jdata,
local_jdata_cpy["repformer"],
local_jdata_cpy["repinit"]["nsel"] = repinit_sel[0]
min_nbor_dist, repformer_sel = update_sel.update_one_sel(

Check warning on line 783 in deepmd/dpmodel/descriptor/dpa2.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/descriptor/dpa2.py#L782-L783

Added lines #L782 - L783 were not covered by tests
train_data,
type_map,
local_jdata_cpy["repformer"]["rcut"],
local_jdata_cpy["repformer"]["nsel"],
True,
rcut_key="rcut",
sel_key="nsel",
)
return local_jdata_cpy
local_jdata_cpy["repformer"]["nsel"] = repformer_sel[0]
return local_jdata_cpy, min_nbor_dist

Check warning on line 791 in deepmd/dpmodel/descriptor/dpa2.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/descriptor/dpa2.py#L790-L791

Added lines #L790 - L791 were not covered by tests
40 changes: 32 additions & 8 deletions deepmd/dpmodel/descriptor/hybrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
Dict,
List,
Optional,
Tuple,
Union,
)

Expand All @@ -19,6 +20,9 @@
from deepmd.dpmodel.utils.nlist import (
nlist_distinguish_types,
)
from deepmd.utils.data_system import (
DeepmdDataSystem,
)
from deepmd.utils.path import (
DPPath,
)
Expand Down Expand Up @@ -222,22 +226,42 @@
return out_descriptor, out_gr, out_g2, out_h2, out_sw

@classmethod
def update_sel(cls, global_jdata: dict, local_jdata: dict) -> dict:
def update_sel(
cls,
train_data: DeepmdDataSystem,
type_map: Optional[List[str]],
local_jdata: dict,
) -> Tuple[dict, Optional[float]]:
"""Update the selection and perform neighbor statistics.

Parameters
----------
global_jdata : dict
The global data, containing the training section
train_data : DeepmdDataSystem
data used to do neighbor statictics
type_map : list[str], optional
The name of each type of atoms
local_jdata : dict
The local data refer to the current class

Returns
-------
dict
The updated local data
float
The minimum distance between two atoms
"""
local_jdata_cpy = local_jdata.copy()
local_jdata_cpy["list"] = [
BaseDescriptor.update_sel(global_jdata, sub_jdata)
for sub_jdata in local_jdata["list"]
]
return local_jdata_cpy
new_list = []
min_nbor_dist = None
for sub_jdata in local_jdata["list"]:
new_sub_jdata, min_nbor_dist_ = BaseDescriptor.update_sel(

Check warning on line 257 in deepmd/dpmodel/descriptor/hybrid.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/descriptor/hybrid.py#L254-L257

Added lines #L254 - L257 were not covered by tests
train_data, type_map, sub_jdata
)
if min_nbor_dist_ is not None:
min_nbor_dist = min_nbor_dist_
new_list.append(new_sub_jdata)
local_jdata_cpy["list"] = new_list
return local_jdata_cpy, min_nbor_dist

Check warning on line 264 in deepmd/dpmodel/descriptor/hybrid.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/descriptor/hybrid.py#L260-L264

Added lines #L260 - L264 were not covered by tests

def serialize(self) -> dict:
return {
Expand Down
26 changes: 22 additions & 4 deletions deepmd/dpmodel/descriptor/make_base_descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,16 @@
Callable,
List,
Optional,
Tuple,
Union,
)

from deepmd.common import (
j_get_type,
)
from deepmd.utils.data_system import (
DeepmdDataSystem,
)
from deepmd.utils.path import (
DPPath,
)
Expand Down Expand Up @@ -149,19 +153,33 @@ def deserialize(cls, data: dict) -> "BD":

@classmethod
@abstractmethod
def update_sel(cls, global_jdata: dict, local_jdata: dict):
def update_sel(
cls,
train_data: DeepmdDataSystem,
type_map: Optional[List[str]],
local_jdata: dict,
) -> Tuple[dict, Optional[float]]:
"""Update the selection and perform neighbor statistics.

Parameters
----------
global_jdata : dict
The global data, containing the training section
train_data : DeepmdDataSystem
data used to do neighbor statictics
type_map : list[str], optional
The name of each type of atoms
local_jdata : dict
The local data refer to the current class

Returns
-------
dict
The updated local data
float
The minimum distance between two atoms
"""
# call subprocess
cls = cls.get_class_by_type(j_get_type(local_jdata, cls.__name__))
return cls.update_sel(global_jdata, local_jdata)
return cls.update_sel(train_data, type_map, local_jdata)
coderabbitai[bot] marked this conversation as resolved.
Show resolved Hide resolved

setattr(BD, fwd_method_name, BD.fwd)
delattr(BD, "fwd")
Expand Down
28 changes: 24 additions & 4 deletions deepmd/dpmodel/descriptor/se_e2_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
from deepmd.env import (
GLOBAL_NP_FLOAT_PRECISION,
)
from deepmd.utils.data_system import (
DeepmdDataSystem,
)
from deepmd.utils.path import (
DPPath,
)
Expand Down Expand Up @@ -422,15 +425,32 @@
return obj

@classmethod
def update_sel(cls, global_jdata: dict, local_jdata: dict):
def update_sel(
cls,
train_data: DeepmdDataSystem,
type_map: Optional[List[str]],
local_jdata: dict,
) -> Tuple[dict, Optional[float]]:
"""Update the selection and perform neighbor statistics.

Parameters
----------
global_jdata : dict
The global data, containing the training section
train_data : DeepmdDataSystem
data used to do neighbor statictics
type_map : list[str], optional
The name of each type of atoms
local_jdata : dict
The local data refer to the current class

Returns
-------
dict
The updated local data
float
The minimum distance between two atoms
"""
local_jdata_cpy = local_jdata.copy()
return UpdateSel().update_one_sel(global_jdata, local_jdata_cpy, False)
min_nbor_dist, local_jdata_cpy["sel"] = UpdateSel().update_one_sel(

Check warning on line 453 in deepmd/dpmodel/descriptor/se_e2_a.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/descriptor/se_e2_a.py#L453

Added line #L453 was not covered by tests
train_data, type_map, local_jdata_cpy["rcut"], local_jdata_cpy["sel"], False
)
return local_jdata_cpy, min_nbor_dist

Check warning on line 456 in deepmd/dpmodel/descriptor/se_e2_a.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/descriptor/se_e2_a.py#L456

Added line #L456 was not covered by tests
coderabbitai[bot] marked this conversation as resolved.
Show resolved Hide resolved
29 changes: 25 additions & 4 deletions deepmd/dpmodel/descriptor/se_r.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
from deepmd.dpmodel.utils.update_sel import (
UpdateSel,
)
from deepmd.utils.data_system import (
DeepmdDataSystem,
)
from deepmd.utils.path import (
DPPath,
)
Expand All @@ -21,6 +24,7 @@
Any,
List,
Optional,
Tuple,
)

from deepmd.dpmodel import (
Expand Down Expand Up @@ -345,15 +349,32 @@
return obj

@classmethod
def update_sel(cls, global_jdata: dict, local_jdata: dict):
def update_sel(
cls,
train_data: DeepmdDataSystem,
type_map: Optional[List[str]],
local_jdata: dict,
) -> Tuple[dict, Optional[float]]:
"""Update the selection and perform neighbor statistics.

Parameters
----------
global_jdata : dict
The global data, containing the training section
train_data : DeepmdDataSystem
data used to do neighbor statictics
type_map : list[str], optional
The name of each type of atoms
local_jdata : dict
The local data refer to the current class

Returns
-------
dict
The updated local data
float
The minimum distance between two atoms
"""
local_jdata_cpy = local_jdata.copy()
return UpdateSel().update_one_sel(global_jdata, local_jdata_cpy, False)
min_nbor_dist, local_jdata_cpy["sel"] = UpdateSel().update_one_sel(

Check warning on line 377 in deepmd/dpmodel/descriptor/se_r.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/descriptor/se_r.py#L377

Added line #L377 was not covered by tests
train_data, type_map, local_jdata_cpy["rcut"], local_jdata_cpy["sel"], False
)
return local_jdata_cpy, min_nbor_dist

Check warning on line 380 in deepmd/dpmodel/descriptor/se_r.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/descriptor/se_r.py#L380

Added line #L380 was not covered by tests
coderabbitai[bot] marked this conversation as resolved.
Show resolved Hide resolved
Loading
Loading