Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat: add zbl training #3398

Merged
merged 71 commits into from
Mar 6, 2024
Merged
Show file tree
Hide file tree
Changes from 70 commits
Commits
Show all changes
71 commits
Select commit Hold shift + click to select a range
063f8e7
feat: add zbl training
anyangml Mar 3, 2024
8f06ab0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 3, 2024
2f7fa77
fix: add atom bias
anyangml Mar 3, 2024
672563c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 3, 2024
312973a
Merge branch 'devel' into devel
anyangml Mar 3, 2024
cf66829
Merge branch 'devel' into devel
anyangml Mar 3, 2024
52ab95f
chore: refactor
anyangml Mar 3, 2024
993efe9
fix: add pairtab stat
anyangml Mar 3, 2024
897f9f3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 3, 2024
e8320b6
Merge branch 'devel' into devel
anyangml Mar 3, 2024
701cb55
fix: add UTs
anyangml Mar 3, 2024
e27a816
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 3, 2024
dc30bbd
fix: add UT input
anyangml Mar 3, 2024
a232cf3
fix: UTs
anyangml Mar 3, 2024
d9856e7
Merge branch 'devel' into devel
anyangml Mar 3, 2024
004b63e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 3, 2024
ca99701
fix: UTs
anyangml Mar 3, 2024
162fc16
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 3, 2024
9c25175
fix: UTs
anyangml Mar 3, 2024
8fc3a70
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 3, 2024
84fb816
chore: merge conflict
anyangml Mar 3, 2024
55e2b7f
fix: update numpy shape
anyangml Mar 3, 2024
0b9f7ae
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 3, 2024
6524694
fix: UTs
anyangml Mar 3, 2024
e3d9a7b
feat: add UTs
anyangml Mar 3, 2024
e648ab4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 3, 2024
7143aa9
Merge branch 'devel' into devel
anyangml Mar 4, 2024
6ed8fde
fix: UTs
anyangml Mar 4, 2024
aadddcb
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 4, 2024
f36988d
fix: UTs
anyangml Mar 4, 2024
9c9cbbe
feat: update UTs
anyangml Mar 4, 2024
d2adebb
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 4, 2024
7071608
Merge branch 'devel' into devel
anyangml Mar 4, 2024
00c877c
fix: UTs
anyangml Mar 4, 2024
eb36de2
Merge branch 'devel' into devel
anyangml Mar 4, 2024
5de7214
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 4, 2024
8b35fa4
rix: revert abstract method
anyangml Mar 4, 2024
bbc7ad2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 4, 2024
c384b3b
fix: UTs
anyangml Mar 4, 2024
1d5fad0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 4, 2024
dc407e3
chore: refactor
anyangml Mar 4, 2024
a9f65be
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 4, 2024
18a4897
fix: precommit
anyangml Mar 4, 2024
94bea6a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 4, 2024
09f9352
fix: precommit
anyangml Mar 4, 2024
a63089d
fix: UTs
anyangml Mar 4, 2024
bda547d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 4, 2024
e6be71b
fix: UTs
anyangml Mar 4, 2024
3482ef2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 4, 2024
a2afe7c
Merge branch 'devel' into devel
anyangml Mar 4, 2024
f067c4c
Merge branch 'devel' into devel
anyangml Mar 5, 2024
b0e4749
Merge branch 'devel' into devel
anyangml Mar 5, 2024
f8e340a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 5, 2024
3610b3d
feat: add atype remap
anyangml Mar 5, 2024
caf5f78
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 5, 2024
7d4e49c
fix: add UTs
anyangml Mar 5, 2024
a30bc35
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 5, 2024
d436444
fix: UTs
anyangml Mar 5, 2024
af1349c
Merge branch 'devel' into devel
anyangml Mar 5, 2024
ba643b9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 5, 2024
b60f038
fix: update numpy
anyangml Mar 5, 2024
e5a905b
chore:skip test
anyangml Mar 5, 2024
25f1ff8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 5, 2024
85baa59
chore: rename class
anyangml Mar 5, 2024
a29541c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 5, 2024
1b04a43
Merge branch 'devel' into devel
anyangml Mar 6, 2024
18ec6a5
fix: add TODO
anyangml Mar 6, 2024
d62bacb
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 6, 2024
c012b3f
chore: refactor remap
anyangml Mar 6, 2024
a0d7caf
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 6, 2024
edd9a10
fix: UTs
anyangml Mar 6, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions deepmd/dpmodel/atomic_model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@
DPAtomicModel,
)
from .linear_atomic_model import (
DPZBLLinearAtomicModel,
LinearAtomicModel,
DPZBLLinearEnergyAtomicModel,
LinearEnergyAtomicModel,
)
from .make_base_atomic_model import (
make_base_atomic_model,
Expand All @@ -37,6 +37,6 @@
"BaseAtomicModel",
"DPAtomicModel",
"PairTabAtomicModel",
"LinearAtomicModel",
"DPZBLLinearAtomicModel",
"LinearEnergyAtomicModel",
"DPZBLLinearEnergyAtomicModel",
]
62 changes: 45 additions & 17 deletions deepmd/dpmodel/atomic_model/linear_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
)


class LinearAtomicModel(BaseAtomicModel):
class LinearEnergyAtomicModel(BaseAtomicModel):
"""Linear model make linear combinations of several existing models.

Parameters
Expand All @@ -59,14 +59,16 @@
self.models = models
sub_model_type_maps = [md.get_type_map() for md in models]
err_msg = []
self.mapping_list = []

Check warning on line 62 in deepmd/dpmodel/atomic_model/linear_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/atomic_model/linear_atomic_model.py#L62

Added line #L62 was not covered by tests
common_type_map = set(type_map)
self.type_map = type_map

Check warning on line 64 in deepmd/dpmodel/atomic_model/linear_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/atomic_model/linear_atomic_model.py#L64

Added line #L64 was not covered by tests
for tpmp in sub_model_type_maps:
if not common_type_map.issubset(set(tpmp)):
err_msg.append(
f"type_map {tpmp} is not a subset of type_map {type_map}"
)
self.mapping_list.append(self.remap_atype(tpmp, self.type_map))

Check warning on line 70 in deepmd/dpmodel/atomic_model/linear_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/atomic_model/linear_atomic_model.py#L70

Added line #L70 was not covered by tests
assert len(err_msg) == 0, "\n".join(err_msg)
self.type_map = type_map
self.mixed_types_list = [model.mixed_types() for model in self.models]
super().__init__(**kwargs)

Expand Down Expand Up @@ -163,17 +165,20 @@
self.mixed_types_list, raw_nlists, self.get_model_sels()
)
]
ener_list = [
model.forward_atomic(
extended_coord,
extended_atype,
nl,
mapping,
fparam,
aparam,
)["energy"]
for model, nl in zip(self.models, nlists_)
]
ener_list = []

Check warning on line 168 in deepmd/dpmodel/atomic_model/linear_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/atomic_model/linear_atomic_model.py#L168

Added line #L168 was not covered by tests

for i, model in enumerate(self.models):
mapping = self.mapping_list[i]
ener_list.append(

Check warning on line 172 in deepmd/dpmodel/atomic_model/linear_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/atomic_model/linear_atomic_model.py#L170-L172

Added lines #L170 - L172 were not covered by tests
model.forward_atomic(
extended_coord,
mapping[extended_atype],
nlists_[i],
mapping,
fparam,
aparam,
)["energy"]
)
self.weights = self._compute_weight(extended_coord, extended_atype, nlists_)
self.atomic_bias = None
if self.atomic_bias is not None:
Expand All @@ -184,6 +189,29 @@
} # (nframes, nloc, 1)
return fit_ret

@staticmethod
def remap_atype(ori_map: List[str], new_map: List[str]) -> np.ndarray:
"""
This method is used to map the atype from the common type_map to the original type_map of
indivial AtomicModels.

Parameters
----------
ori_map : List[str]
The original type map of an AtomicModel.
new_map : List[str]
The common type map of the DPZBLLinearEnergyAtomicModel, created by the `get_type_map` method,
must be a subset of the ori_map.

Returns
-------
np.ndarray
"""
type_2_idx = {atp: idx for idx, atp in enumerate(ori_map)}

Check warning on line 210 in deepmd/dpmodel/atomic_model/linear_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/atomic_model/linear_atomic_model.py#L210

Added line #L210 was not covered by tests
# this maps the atype in the new map to the original map
mapping = np.array([type_2_idx[new_map[idx]] for idx in range(len(new_map))])
return mapping

Check warning on line 213 in deepmd/dpmodel/atomic_model/linear_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/atomic_model/linear_atomic_model.py#L212-L213

Added lines #L212 - L213 were not covered by tests

def fitting_output_def(self) -> FittingOutputDef:
return FittingOutputDef(
[
Expand Down Expand Up @@ -261,7 +289,7 @@
return False


class DPZBLLinearAtomicModel(LinearAtomicModel):
class DPZBLLinearEnergyAtomicModel(LinearEnergyAtomicModel):
"""Model linearly combine a list of AtomicModels.

Parameters
Expand Down Expand Up @@ -308,7 +336,7 @@
"@class": "Model",
"type": "zbl",
"@version": 1,
"models": LinearAtomicModel.serialize(
"models": LinearEnergyAtomicModel.serialize(
[self.dp_model, self.zbl_model], self.type_map
),
"sw_rmin": self.sw_rmin,
Expand All @@ -319,7 +347,7 @@
return dd

@classmethod
def deserialize(cls, data) -> "DPZBLLinearAtomicModel":
def deserialize(cls, data) -> "DPZBLLinearEnergyAtomicModel":
data = copy.deepcopy(data)
check_version_compatibility(data.pop("@version", 1), 1, 1)
data.pop("@class")
Expand All @@ -328,7 +356,7 @@
sw_rmax = data.pop("sw_rmax")
smin_alpha = data.pop("smin_alpha")

([dp_model, zbl_model], type_map) = LinearAtomicModel.deserialize(
([dp_model, zbl_model], type_map) = LinearEnergyAtomicModel.deserialize(

Check warning on line 359 in deepmd/dpmodel/atomic_model/linear_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/atomic_model/linear_atomic_model.py#L359

Added line #L359 was not covered by tests
data.pop("models")
)

Expand Down
11 changes: 3 additions & 8 deletions deepmd/dpmodel/atomic_model/make_base_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,18 +54,13 @@
pass

@abstractmethod
def get_type_map(self) -> Optional[List[str]]:
def get_type_map(self) -> List[str]:
"""Get the type map."""
pass

Check warning on line 59 in deepmd/dpmodel/atomic_model/make_base_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/atomic_model/make_base_atomic_model.py#L59

Added line #L59 was not covered by tests

def get_ntypes(self) -> int:
"""Get the number of atom types."""
tmap = self.get_type_map()
if tmap is not None:
return len(tmap)
else:
raise ValueError(
"cannot infer the number of types from a None type map"
)
return len(self.get_type_map())

Check warning on line 63 in deepmd/dpmodel/atomic_model/make_base_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/atomic_model/make_base_atomic_model.py#L63

Added line #L63 was not covered by tests

@abstractmethod
def get_sel(self) -> List[int]:
Expand Down
11 changes: 10 additions & 1 deletion deepmd/dpmodel/atomic_model/pairtab_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,17 @@
self.type_map = type_map

self.tab = PairTab(self.tab_file, rcut=rcut)
self.type_map = type_map
self.ntypes = len(type_map)

Check warning on line 70 in deepmd/dpmodel/atomic_model/pairtab_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/atomic_model/pairtab_atomic_model.py#L69-L70

Added lines #L69 - L70 were not covered by tests

if self.tab_file is not None:
self.tab_info, self.tab_data = self.tab.get()
nspline, ntypes_tab = self.tab_info[-2:].astype(int)
self.tab_data = self.tab_data.reshape(ntypes_tab, ntypes_tab, nspline, 4)
if self.ntypes != ntypes_tab:
raise ValueError(

Check warning on line 77 in deepmd/dpmodel/atomic_model/pairtab_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/atomic_model/pairtab_atomic_model.py#L74-L77

Added lines #L74 - L77 were not covered by tests
"The `type_map` provided does not match the number of columns in the table."
)
else:
self.tab_info, self.tab_data = None, None

Expand Down Expand Up @@ -145,7 +153,8 @@
tab_model = cls(None, rcut, sel, type_map, **data)
tab_model.tab = tab
tab_model.tab_info = tab_model.tab.tab_info
tab_model.tab_data = tab_model.tab.tab_data
nspline, ntypes = tab_model.tab_info[-2:].astype(int)
tab_model.tab_data = tab_model.tab.tab_data.reshape(ntypes, ntypes, nspline, 4)

Check warning on line 157 in deepmd/dpmodel/atomic_model/pairtab_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/atomic_model/pairtab_atomic_model.py#L156-L157

Added lines #L156 - L157 were not covered by tests
return tab_model

def forward_atomic(
Expand Down
8 changes: 4 additions & 4 deletions deepmd/pt/model/atomic_model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@
DPAtomicModel,
)
from .linear_atomic_model import (
DPZBLLinearAtomicModel,
LinearAtomicModel,
DPZBLLinearEnergyAtomicModel,
LinearEnergyAtomicModel,
)
from .pairtab_atomic_model import (
PairTabAtomicModel,
Expand All @@ -32,6 +32,6 @@
"BaseAtomicModel",
"DPAtomicModel",
"PairTabAtomicModel",
"LinearAtomicModel",
"DPZBLLinearAtomicModel",
"LinearEnergyAtomicModel",
"DPZBLLinearEnergyAtomicModel",
]
Loading
Loading