Skip to content

Commit

Permalink
refactor: refactor update_sel and save min_nbor_dist (#3829)
Browse files Browse the repository at this point in the history
Fix #3525. Fix
#3544.

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

- **New Features**
- Enhanced `update_sel` method to accept additional parameters and
return more detailed data, improving model selection and neighbor
statistics.
- **Bug Fixes**
- Improved handling and processing of training data to enhance model
accuracy.
- **Refactor**
- Updated method signatures and logic for consistency and better
performance.
- **Chores**
  - Removed unused `hook` method to streamline codebase.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Jinzhe Zeng <[email protected]>
Co-authored-by: Han Wang <[email protected]>
  • Loading branch information
njzjz and wanghan-iapcm authored May 31, 2024
1 parent 84b711e commit 3a7fbcf
Show file tree
Hide file tree
Showing 44 changed files with 1,351 additions and 335 deletions.
29 changes: 25 additions & 4 deletions deepmd/dpmodel/descriptor/dpa1.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
from deepmd.env import (
GLOBAL_NP_FLOAT_PRECISION,
)
from deepmd.utils.data_system import (
DeepmdDataSystem,
)
from deepmd.utils.path import (
DPPath,
)
Expand Down Expand Up @@ -524,18 +527,36 @@ def deserialize(cls, data: dict) -> "DescrptDPA1":
return obj

@classmethod
def update_sel(cls, global_jdata: dict, local_jdata: dict):
def update_sel(
cls,
train_data: DeepmdDataSystem,
type_map: Optional[List[str]],
local_jdata: dict,
) -> Tuple[dict, Optional[float]]:
"""Update the selection and perform neighbor statistics.
Parameters
----------
global_jdata : dict
The global data, containing the training section
train_data : DeepmdDataSystem
data used to do neighbor statictics
type_map : list[str], optional
The name of each type of atoms
local_jdata : dict
The local data refer to the current class
Returns
-------
dict
The updated local data
float
The minimum distance between two atoms
"""
local_jdata_cpy = local_jdata.copy()
return UpdateSel().update_one_sel(global_jdata, local_jdata_cpy, True)
min_nbor_dist, sel = UpdateSel().update_one_sel(
train_data, type_map, local_jdata_cpy["rcut"], local_jdata_cpy["sel"], True
)
local_jdata_cpy["sel"] = sel[0]
return local_jdata_cpy, min_nbor_dist


@DescriptorBlock.register("se_atten")
Expand Down
47 changes: 33 additions & 14 deletions deepmd/dpmodel/descriptor/dpa2.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
from deepmd.dpmodel.utils.update_sel import (
UpdateSel,
)
from deepmd.utils.data_system import (
DeepmdDataSystem,
)
from deepmd.utils.path import (
DPPath,
)
Expand Down Expand Up @@ -743,30 +746,46 @@ def deserialize(cls, data: dict) -> "DescrptDPA2":
return obj

@classmethod
def update_sel(cls, global_jdata: dict, local_jdata: dict):
def update_sel(
cls,
train_data: DeepmdDataSystem,
type_map: Optional[List[str]],
local_jdata: dict,
) -> Tuple[dict, Optional[float]]:
"""Update the selection and perform neighbor statistics.
Parameters
----------
global_jdata : dict
The global data, containing the training section
train_data : DeepmdDataSystem
data used to do neighbor statictics
type_map : list[str], optional
The name of each type of atoms
local_jdata : dict
The local data refer to the current class
Returns
-------
dict
The updated local data
float
The minimum distance between two atoms
"""
local_jdata_cpy = local_jdata.copy()
update_sel = UpdateSel()
local_jdata_cpy["repinit"] = update_sel.update_one_sel(
global_jdata,
local_jdata_cpy["repinit"],
min_nbor_dist, repinit_sel = update_sel.update_one_sel(
train_data,
type_map,
local_jdata_cpy["repinit"]["rcut"],
local_jdata_cpy["repinit"]["nsel"],
True,
rcut_key="rcut",
sel_key="nsel",
)
local_jdata_cpy["repformer"] = update_sel.update_one_sel(
global_jdata,
local_jdata_cpy["repformer"],
local_jdata_cpy["repinit"]["nsel"] = repinit_sel[0]
min_nbor_dist, repformer_sel = update_sel.update_one_sel(
train_data,
type_map,
local_jdata_cpy["repformer"]["rcut"],
local_jdata_cpy["repformer"]["nsel"],
True,
rcut_key="rcut",
sel_key="nsel",
)
return local_jdata_cpy
local_jdata_cpy["repformer"]["nsel"] = repformer_sel[0]
return local_jdata_cpy, min_nbor_dist
40 changes: 32 additions & 8 deletions deepmd/dpmodel/descriptor/hybrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
Dict,
List,
Optional,
Tuple,
Union,
)

Expand All @@ -19,6 +20,9 @@
from deepmd.dpmodel.utils.nlist import (
nlist_distinguish_types,
)
from deepmd.utils.data_system import (
DeepmdDataSystem,
)
from deepmd.utils.path import (
DPPath,
)
Expand Down Expand Up @@ -222,22 +226,42 @@ def call(
return out_descriptor, out_gr, out_g2, out_h2, out_sw

@classmethod
def update_sel(cls, global_jdata: dict, local_jdata: dict) -> dict:
def update_sel(
cls,
train_data: DeepmdDataSystem,
type_map: Optional[List[str]],
local_jdata: dict,
) -> Tuple[dict, Optional[float]]:
"""Update the selection and perform neighbor statistics.
Parameters
----------
global_jdata : dict
The global data, containing the training section
train_data : DeepmdDataSystem
data used to do neighbor statictics
type_map : list[str], optional
The name of each type of atoms
local_jdata : dict
The local data refer to the current class
Returns
-------
dict
The updated local data
float
The minimum distance between two atoms
"""
local_jdata_cpy = local_jdata.copy()
local_jdata_cpy["list"] = [
BaseDescriptor.update_sel(global_jdata, sub_jdata)
for sub_jdata in local_jdata["list"]
]
return local_jdata_cpy
new_list = []
min_nbor_dist = None
for sub_jdata in local_jdata["list"]:
new_sub_jdata, min_nbor_dist_ = BaseDescriptor.update_sel(
train_data, type_map, sub_jdata
)
if min_nbor_dist_ is not None:
min_nbor_dist = min_nbor_dist_
new_list.append(new_sub_jdata)
local_jdata_cpy["list"] = new_list
return local_jdata_cpy, min_nbor_dist

def serialize(self) -> dict:
return {
Expand Down
26 changes: 22 additions & 4 deletions deepmd/dpmodel/descriptor/make_base_descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,16 @@
Callable,
List,
Optional,
Tuple,
Union,
)

from deepmd.common import (
j_get_type,
)
from deepmd.utils.data_system import (
DeepmdDataSystem,
)
from deepmd.utils.path import (
DPPath,
)
Expand Down Expand Up @@ -149,19 +153,33 @@ def deserialize(cls, data: dict) -> "BD":

@classmethod
@abstractmethod
def update_sel(cls, global_jdata: dict, local_jdata: dict):
def update_sel(
cls,
train_data: DeepmdDataSystem,
type_map: Optional[List[str]],
local_jdata: dict,
) -> Tuple[dict, Optional[float]]:
"""Update the selection and perform neighbor statistics.
Parameters
----------
global_jdata : dict
The global data, containing the training section
train_data : DeepmdDataSystem
data used to do neighbor statictics
type_map : list[str], optional
The name of each type of atoms
local_jdata : dict
The local data refer to the current class
Returns
-------
dict
The updated local data
float
The minimum distance between two atoms
"""
# call subprocess
cls = cls.get_class_by_type(j_get_type(local_jdata, cls.__name__))
return cls.update_sel(global_jdata, local_jdata)
return cls.update_sel(train_data, type_map, local_jdata)

setattr(BD, fwd_method_name, BD.fwd)
delattr(BD, "fwd")
Expand Down
28 changes: 24 additions & 4 deletions deepmd/dpmodel/descriptor/se_e2_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
from deepmd.env import (
GLOBAL_NP_FLOAT_PRECISION,
)
from deepmd.utils.data_system import (
DeepmdDataSystem,
)
from deepmd.utils.path import (
DPPath,
)
Expand Down Expand Up @@ -422,15 +425,32 @@ def deserialize(cls, data: dict) -> "DescrptSeA":
return obj

@classmethod
def update_sel(cls, global_jdata: dict, local_jdata: dict):
def update_sel(
cls,
train_data: DeepmdDataSystem,
type_map: Optional[List[str]],
local_jdata: dict,
) -> Tuple[dict, Optional[float]]:
"""Update the selection and perform neighbor statistics.
Parameters
----------
global_jdata : dict
The global data, containing the training section
train_data : DeepmdDataSystem
data used to do neighbor statictics
type_map : list[str], optional
The name of each type of atoms
local_jdata : dict
The local data refer to the current class
Returns
-------
dict
The updated local data
float
The minimum distance between two atoms
"""
local_jdata_cpy = local_jdata.copy()
return UpdateSel().update_one_sel(global_jdata, local_jdata_cpy, False)
min_nbor_dist, local_jdata_cpy["sel"] = UpdateSel().update_one_sel(
train_data, type_map, local_jdata_cpy["rcut"], local_jdata_cpy["sel"], False
)
return local_jdata_cpy, min_nbor_dist
29 changes: 25 additions & 4 deletions deepmd/dpmodel/descriptor/se_r.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
from deepmd.dpmodel.utils.update_sel import (
UpdateSel,
)
from deepmd.utils.data_system import (
DeepmdDataSystem,
)
from deepmd.utils.path import (
DPPath,
)
Expand All @@ -21,6 +24,7 @@
Any,
List,
Optional,
Tuple,
)

from deepmd.dpmodel import (
Expand Down Expand Up @@ -345,15 +349,32 @@ def deserialize(cls, data: dict) -> "DescrptSeR":
return obj

@classmethod
def update_sel(cls, global_jdata: dict, local_jdata: dict):
def update_sel(
cls,
train_data: DeepmdDataSystem,
type_map: Optional[List[str]],
local_jdata: dict,
) -> Tuple[dict, Optional[float]]:
"""Update the selection and perform neighbor statistics.
Parameters
----------
global_jdata : dict
The global data, containing the training section
train_data : DeepmdDataSystem
data used to do neighbor statictics
type_map : list[str], optional
The name of each type of atoms
local_jdata : dict
The local data refer to the current class
Returns
-------
dict
The updated local data
float
The minimum distance between two atoms
"""
local_jdata_cpy = local_jdata.copy()
return UpdateSel().update_one_sel(global_jdata, local_jdata_cpy, False)
min_nbor_dist, local_jdata_cpy["sel"] = UpdateSel().update_one_sel(
train_data, type_map, local_jdata_cpy["rcut"], local_jdata_cpy["sel"], False
)
return local_jdata_cpy, min_nbor_dist
28 changes: 24 additions & 4 deletions deepmd/dpmodel/descriptor/se_t.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
from deepmd.env import (
GLOBAL_NP_FLOAT_PRECISION,
)
from deepmd.utils.data_system import (
DeepmdDataSystem,
)
from deepmd.utils.path import (
DPPath,
)
Expand Down Expand Up @@ -348,15 +351,32 @@ def deserialize(cls, data: dict) -> "DescrptSeT":
return obj

@classmethod
def update_sel(cls, global_jdata: dict, local_jdata: dict):
def update_sel(
cls,
train_data: DeepmdDataSystem,
type_map: Optional[List[str]],
local_jdata: dict,
) -> Tuple[dict, Optional[float]]:
"""Update the selection and perform neighbor statistics.
Parameters
----------
global_jdata : dict
The global data, containing the training section
train_data : DeepmdDataSystem
data used to do neighbor statictics
type_map : list[str], optional
The name of each type of atoms
local_jdata : dict
The local data refer to the current class
Returns
-------
dict
The updated local data
float
The minimum distance between two atoms
"""
local_jdata_cpy = local_jdata.copy()
return UpdateSel().update_one_sel(global_jdata, local_jdata_cpy, False)
min_nbor_dist, local_jdata_cpy["sel"] = UpdateSel().update_one_sel(
train_data, type_map, local_jdata_cpy["rcut"], local_jdata_cpy["sel"], False
)
return local_jdata_cpy, min_nbor_dist
Loading

0 comments on commit 3a7fbcf

Please sign in to comment.