Feat: add zbl training #3398

Merged
merged 71 commits on Mar 6, 2024
Changes from 66 commits
Commits (71)
063f8e7
feat: add zbl training
anyangml Mar 3, 2024
8f06ab0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 3, 2024
2f7fa77
fix: add atom bias
anyangml Mar 3, 2024
672563c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 3, 2024
312973a
Merge branch 'devel' into devel
anyangml Mar 3, 2024
cf66829
Merge branch 'devel' into devel
anyangml Mar 3, 2024
52ab95f
chore: refactor
anyangml Mar 3, 2024
993efe9
fix: add pairtab stat
anyangml Mar 3, 2024
897f9f3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 3, 2024
e8320b6
Merge branch 'devel' into devel
anyangml Mar 3, 2024
701cb55
fix: add UTs
anyangml Mar 3, 2024
e27a816
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 3, 2024
dc30bbd
fix: add UT input
anyangml Mar 3, 2024
a232cf3
fix: UTs
anyangml Mar 3, 2024
d9856e7
Merge branch 'devel' into devel
anyangml Mar 3, 2024
004b63e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 3, 2024
ca99701
fix: UTs
anyangml Mar 3, 2024
162fc16
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 3, 2024
9c25175
fix: UTs
anyangml Mar 3, 2024
8fc3a70
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 3, 2024
84fb816
chore: merge conflict
anyangml Mar 3, 2024
55e2b7f
fix: update numpy shape
anyangml Mar 3, 2024
0b9f7ae
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 3, 2024
6524694
fix: UTs
anyangml Mar 3, 2024
e3d9a7b
feat: add UTs
anyangml Mar 3, 2024
e648ab4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 3, 2024
7143aa9
Merge branch 'devel' into devel
anyangml Mar 4, 2024
6ed8fde
fix: UTs
anyangml Mar 4, 2024
aadddcb
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 4, 2024
f36988d
fix: UTs
anyangml Mar 4, 2024
9c9cbbe
feat: update UTs
anyangml Mar 4, 2024
d2adebb
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 4, 2024
7071608
Merge branch 'devel' into devel
anyangml Mar 4, 2024
00c877c
fix: UTs
anyangml Mar 4, 2024
eb36de2
Merge branch 'devel' into devel
anyangml Mar 4, 2024
5de7214
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 4, 2024
8b35fa4
fix: revert abstract method
anyangml Mar 4, 2024
bbc7ad2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 4, 2024
c384b3b
fix: UTs
anyangml Mar 4, 2024
1d5fad0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 4, 2024
dc407e3
chore: refactor
anyangml Mar 4, 2024
a9f65be
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 4, 2024
18a4897
fix: precommit
anyangml Mar 4, 2024
94bea6a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 4, 2024
09f9352
fix: precommit
anyangml Mar 4, 2024
a63089d
fix: UTs
anyangml Mar 4, 2024
bda547d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 4, 2024
e6be71b
fix: UTs
anyangml Mar 4, 2024
3482ef2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 4, 2024
a2afe7c
Merge branch 'devel' into devel
anyangml Mar 4, 2024
f067c4c
Merge branch 'devel' into devel
anyangml Mar 5, 2024
b0e4749
Merge branch 'devel' into devel
anyangml Mar 5, 2024
f8e340a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 5, 2024
3610b3d
feat: add atype remap
anyangml Mar 5, 2024
caf5f78
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 5, 2024
7d4e49c
fix: add UTs
anyangml Mar 5, 2024
a30bc35
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 5, 2024
d436444
fix: UTs
anyangml Mar 5, 2024
af1349c
Merge branch 'devel' into devel
anyangml Mar 5, 2024
ba643b9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 5, 2024
b60f038
fix: update numpy
anyangml Mar 5, 2024
e5a905b
chore: skip test
anyangml Mar 5, 2024
25f1ff8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 5, 2024
85baa59
chore: rename class
anyangml Mar 5, 2024
a29541c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 5, 2024
1b04a43
Merge branch 'devel' into devel
anyangml Mar 6, 2024
18ec6a5
fix: add TODO
anyangml Mar 6, 2024
d62bacb
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 6, 2024
c012b3f
chore: refactor remap
anyangml Mar 6, 2024
a0d7caf
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 6, 2024
edd9a10
fix: UTs
anyangml Mar 6, 2024
8 changes: 4 additions & 4 deletions deepmd/dpmodel/atomic_model/__init__.py
@@ -22,8 +22,8 @@
DPAtomicModel,
)
from .linear_atomic_model import (
DPZBLLinearAtomicModel,
LinearAtomicModel,
DPZBLLinearEnergyAtomicModel,
LinearEnergyAtomicModel,
)
from .make_base_atomic_model import (
make_base_atomic_model,
@@ -37,6 +37,6 @@
"BaseAtomicModel",
"DPAtomicModel",
"PairTabAtomicModel",
"LinearAtomicModel",
"DPZBLLinearAtomicModel",
"LinearEnergyAtomicModel",
"DPZBLLinearEnergyAtomicModel",
]
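For downstream callers, the rename shows up only at import time. A minimal before/after sketch (hypothetical caller code, not part of this diff):

```python
# Old (pre-PR) names, kept as comments for contrast:
# from deepmd.dpmodel.atomic_model import DPZBLLinearAtomicModel, LinearAtomicModel

# New names after this PR:
from deepmd.dpmodel.atomic_model import (
    DPZBLLinearEnergyAtomicModel,
    LinearEnergyAtomicModel,
)
```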
67 changes: 51 additions & 16 deletions deepmd/dpmodel/atomic_model/linear_atomic_model.py
@@ -38,7 +38,7 @@
)


class LinearAtomicModel(BaseAtomicModel):
class LinearEnergyAtomicModel(BaseAtomicModel):
"""Linear model make linear combinations of several existing models.

Parameters
@@ -163,17 +163,21 @@ def forward_atomic(
self.mixed_types_list, raw_nlists, self.get_model_sels()
)
]
ener_list = [
model.forward_atomic(
extended_coord,
extended_atype,
nl,
mapping,
fparam,
aparam,
)["energy"]
for model, nl in zip(self.models, nlists_)
]
ener_list = []

for i, model in enumerate(self.models):
ori_map = model.get_type_map()
updated_atype = self.remap_atype(extended_atype, ori_map, self.type_map)
ener_list.append(
model.forward_atomic(
extended_coord,
updated_atype,
nlists_[i],
mapping,
fparam,
aparam,
)["energy"]
)
self.weights = self._compute_weight(extended_coord, extended_atype, nlists_)
self.atomic_bias = None
if self.atomic_bias is not None:
@@ -184,6 +188,37 @@
} # (nframes, nloc, 1)
return fit_ret

@staticmethod
def remap_atype(
atype: np.ndarray, ori_map: List[str], new_map: List[str]
) -> np.ndarray:
"""
This method is used to map the atype from the common type_map to the original type_map of
individual AtomicModels.

Parameters
----------
atype : np.ndarray
The atom type tensor being updated, shape of (nframes, natoms)
ori_map : List[str]
The original type map of an AtomicModel.
new_map : List[str]
The common type map of the DPZBLLinearEnergyAtomicModel, created by the `get_type_map` method,
must be a subset of the ori_map.

Returns
-------
np.ndarray
"""
assert np.max(atype) < len(
new_map
), "The input `atype` cannot be handled by the type_map."
type_2_idx = {atp: idx for idx, atp in enumerate(ori_map)}
# this maps the atype in the new map to the original map
mapping = np.array([type_2_idx[new_map[idx]] for idx in range(len(new_map))])
updated_atype = mapping[atype]
return updated_atype

def fitting_output_def(self) -> FittingOutputDef:
return FittingOutputDef(
[
@@ -261,7 +296,7 @@ def is_aparam_nall(self) -> bool:
return False


class DPZBLLinearAtomicModel(LinearAtomicModel):
class DPZBLLinearEnergyAtomicModel(LinearEnergyAtomicModel):
"""Model linearly combine a list of AtomicModels.

Parameters
@@ -308,7 +343,7 @@ def serialize(self) -> dict:
"@class": "Model",
"type": "zbl",
"@version": 1,
"models": LinearAtomicModel.serialize(
"models": LinearEnergyAtomicModel.serialize(
[self.dp_model, self.zbl_model], self.type_map
),
"sw_rmin": self.sw_rmin,
@@ -319,7 +354,7 @@
return dd

@classmethod
def deserialize(cls, data) -> "DPZBLLinearAtomicModel":
def deserialize(cls, data) -> "DPZBLLinearEnergyAtomicModel":
data = copy.deepcopy(data)
check_version_compatibility(data.pop("@version", 1), 1, 1)
data.pop("@class")
@@ -328,7 +363,7 @@ def deserialize(cls, data) -> "DPZBLLinearAtomicModel":
sw_rmax = data.pop("sw_rmax")
smin_alpha = data.pop("smin_alpha")

([dp_model, zbl_model], type_map) = LinearAtomicModel.deserialize(
([dp_model, zbl_model], type_map) = LinearEnergyAtomicModel.deserialize(
data.pop("models")
)

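To see what `remap_atype` above actually does, here is a minimal, self-contained numpy sketch. The logic mirrors the diff, but the type maps and `atype` values are invented for illustration:

```python
from typing import List

import numpy as np


def remap_atype(atype: np.ndarray, ori_map: List[str], new_map: List[str]) -> np.ndarray:
    # Mirrors the diff: look up each common-map type name in the original map,
    # then index that lookup table with the incoming atype tensor.
    assert np.max(atype) < len(new_map), "The input `atype` cannot be handled by the type_map."
    type_2_idx = {atp: idx for idx, atp in enumerate(ori_map)}
    mapping = np.array([type_2_idx[new_map[idx]] for idx in range(len(new_map))])
    return mapping[atype]


# Hypothetical maps, not taken from the PR: a sub-model trained with
# ["O", "H"] combined under a common type map of ["H", "O"].
atype = np.array([[0, 1, 1]])  # "H", "O", "O" in the common map
print(remap_atype(atype, ori_map=["O", "H"], new_map=["H", "O"]))  # [[1 0 0]]
```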
11 changes: 3 additions & 8 deletions deepmd/dpmodel/atomic_model/make_base_atomic_model.py
@@ -54,18 +54,13 @@ def get_rcut(self) -> float:
pass

@abstractmethod
def get_type_map(self) -> Optional[List[str]]:
def get_type_map(self) -> List[str]:
"""Get the type map."""
pass

def get_ntypes(self) -> int:
"""Get the number of atom types."""
tmap = self.get_type_map()
if tmap is not None:
return len(tmap)
else:
raise ValueError(
"cannot infer the number of types from a None type map"
)
return len(self.get_type_map())

@abstractmethod
def get_sel(self) -> List[int]:
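With `get_type_map` now returning a non-Optional list, `get_ntypes` no longer needs the `None` guard. A toy sketch of the simplified contract (hypothetical minimal class, not the real ABC):

```python
from typing import List


class ToyAtomicModel:
    # Hypothetical stand-in: type_map is now guaranteed to be a real list.
    def __init__(self, type_map: List[str]) -> None:
        self.type_map = type_map

    def get_type_map(self) -> List[str]:
        return self.type_map

    def get_ntypes(self) -> int:
        # With a non-Optional type map, the length is always well defined.
        return len(self.get_type_map())


print(ToyAtomicModel(["H", "O"]).get_ntypes())  # 2
```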
11 changes: 10 additions & 1 deletion deepmd/dpmodel/atomic_model/pairtab_atomic_model.py
@@ -66,9 +66,17 @@ def __init__(
self.type_map = type_map

self.tab = PairTab(self.tab_file, rcut=rcut)
self.type_map = type_map
self.ntypes = len(type_map)

if self.tab_file is not None:
self.tab_info, self.tab_data = self.tab.get()
nspline, ntypes_tab = self.tab_info[-2:].astype(int)
self.tab_data = self.tab_data.reshape(ntypes_tab, ntypes_tab, nspline, 4)
if self.ntypes != ntypes_tab:
raise ValueError(
"The `type_map` provided does not match the number of columns in the table."
)
else:
self.tab_info, self.tab_data = None, None

@@ -145,7 +153,8 @@ def deserialize(cls, data) -> "PairTabAtomicModel":
tab_model = cls(None, rcut, sel, type_map, **data)
tab_model.tab = tab
tab_model.tab_info = tab_model.tab.tab_info
tab_model.tab_data = tab_model.tab.tab_data
nspline, ntypes = tab_model.tab_info[-2:].astype(int)
tab_model.tab_data = tab_model.tab.tab_data.reshape(ntypes, ntypes, nspline, 4)
return tab_model

def forward_atomic(
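A small sketch of the reshape and validation added above. `tab_info` and `tab_data` here are fabricated stand-ins shaped like what `PairTab.get()` returns, with the last two header entries being `nspline` and the table's `ntypes` as in the diff:

```python
import numpy as np

# Fabricated stand-ins: tab_info ends with (nspline, ntypes); tab_data is flat,
# with 4 cubic-spline coefficients per point.
ntypes, nspline = 2, 3
tab_info = np.array([0.0, 0.1, nspline, ntypes], dtype=float)
tab_data = np.arange(ntypes * ntypes * nspline * 4, dtype=float)

nspline_, ntypes_tab = tab_info[-2:].astype(int)
if ntypes != ntypes_tab:
    raise ValueError(
        "The `type_map` provided does not match the number of columns in the table."
    )
# One (nspline, 4) coefficient block per (type_i, type_j) pair.
tab_data = tab_data.reshape(ntypes_tab, ntypes_tab, nspline_, 4)
print(tab_data.shape)  # (2, 2, 3, 4)
```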
8 changes: 4 additions & 4 deletions deepmd/pt/model/atomic_model/__init__.py
@@ -21,8 +21,8 @@
DPAtomicModel,
)
from .linear_atomic_model import (
DPZBLLinearAtomicModel,
LinearAtomicModel,
DPZBLLinearEnergyAtomicModel,
LinearEnergyAtomicModel,
)
from .pairtab_atomic_model import (
PairTabAtomicModel,
@@ -32,6 +32,6 @@
"BaseAtomicModel",
"DPAtomicModel",
"PairTabAtomicModel",
"LinearAtomicModel",
"DPZBLLinearAtomicModel",
"LinearEnergyAtomicModel",
"DPZBLLinearEnergyAtomicModel",
]
105 changes: 88 additions & 17 deletions deepmd/pt/model/atomic_model/linear_atomic_model.py
@@ -25,6 +25,9 @@
get_multiple_nlist_key,
nlist_distinguish_types,
)
from deepmd.utils.path import (
DPPath,
)
from deepmd.utils.version import (
check_version_compatibility,
)
@@ -40,7 +43,7 @@
)


class LinearAtomicModel(torch.nn.Module, BaseAtomicModel):
class LinearEnergyAtomicModel(torch.nn.Module, BaseAtomicModel):
"""Linear model make linear combinations of several existing models.

Parameters
@@ -93,7 +96,6 @@ def get_rcut(self) -> float:

@torch.jit.export
def get_type_map(self) -> List[str]:
"""Get the type map."""
return self.type_map

def get_model_rcuts(self) -> List[float]:
@@ -117,8 +119,8 @@ def _sort_rcuts_sels(self, device: torch.device) -> Tuple[List[float], List[int]
nsels = torch.tensor(self.get_model_nsels(), device=device)
zipped = torch.stack(
[
torch.tensor(rcuts, device=device),
torch.tensor(nsels, device=device),
rcuts,
nsels,
],
dim=0,
).T
@@ -185,10 +187,12 @@
ener_list = []

for i, model in enumerate(self.models):
ori_map = model.get_type_map()
updated_atype = self.remap_atype(extended_atype, ori_map, self.type_map)
ener_list.append(
model.forward_atomic(
extended_coord,
extended_atype,
updated_atype,
nlists_[i],
mapping,
fparam,
@@ -198,16 +202,56 @@

weights = self._compute_weight(extended_coord, extended_atype, nlists_)

if self.atomic_bias is not None:
raise NotImplementedError("Need to add bias in a future PR.")
else:
fit_ret = {
"energy": torch.sum(
torch.stack(ener_list) * torch.stack(weights), dim=0
),
} # (nframes, nloc, 1)
atype = extended_atype[:, :nloc]
for idx, model in enumerate(self.models):
if isinstance(model, DPAtomicModel):
bias_atom_e = model.fitting_net.bias_atom_e
elif isinstance(model, PairTabAtomicModel):
bias_atom_e = model.bias_atom_e
else:
bias_atom_e = None
if bias_atom_e is not None:
ener_list[idx] += bias_atom_e[atype]

fit_ret = {
"energy": torch.sum(torch.stack(ener_list) * torch.stack(weights), dim=0),
} # (nframes, nloc, 1)
return fit_ret

@staticmethod
def remap_atype(
atype: torch.Tensor, ori_map: List[str], new_map: List[str]
) -> torch.Tensor:
"""
This method is used to map the atype from the common type_map to the original type_map of
individual AtomicModels.

Parameters
----------
atype : torch.Tensor
The atom type tensor being updated, shape of (nframes, natoms)
ori_map : List[str]
The original type map of an AtomicModel.
new_map : List[str]
The common type map of the DPZBLLinearEnergyAtomicModel, created by the `get_type_map` method,
must be a subset of the ori_map.

Returns
-------
torch.Tensor
"""
assert torch.max(atype) < len(
new_map
), "The input `atype` cannot be handled by the type_map."
type_2_idx = {atp: idx for idx, atp in enumerate(ori_map)}
# this maps the atype in the new map to the original map
mapping = torch.tensor(
[type_2_idx[new_map[idx]] for idx in range(len(new_map))],
device=atype.device,
)
updated_atype = mapping[atype.long()]
return updated_atype

def fitting_output_def(self) -> FittingOutputDef:
return FittingOutputDef(
[
@@ -292,7 +336,7 @@ def is_aparam_nall(self) -> bool:
return False


class DPZBLLinearAtomicModel(LinearAtomicModel):
class DPZBLLinearEnergyAtomicModel(LinearEnergyAtomicModel):
"""Model linearly combine a list of AtomicModels.

Parameters
@@ -336,14 +380,41 @@ def __init__(
# this is a placeholder being updated in _compute_weight, to handle Jit attribute init error.
self.zbl_weight = torch.empty(0, dtype=torch.float64, device=env.DEVICE)

def compute_or_load_stat(
self,
sampled_func,
stat_file_path: Optional[DPPath] = None,
):
"""
Compute or load the statistics parameters of the model,
such as mean and standard deviation of descriptors or the energy bias of the fitting net.
When `sampled_func` is provided, all the statistics parameters will be calculated (or re-calculated for update),
and saved in the `stat_file_path`(s).
When `sampled_func` is not provided, it will check the existence of `stat_file_path`(s)
and load the calculated statistics parameters.

Parameters
----------
sampled_func
The lazy sampled function to get data frames from different data systems.
stat_file_path
The dictionary of paths to the statistics files.
"""
self.dp_model.compute_or_load_stat(sampled_func, stat_file_path)
self.zbl_model.compute_or_load_stat(sampled_func, stat_file_path)

def change_energy_bias(self):
# need to implement
pass

def serialize(self) -> dict:
dd = BaseAtomicModel.serialize(self)
dd.update(
{
"@class": "Model",
"@version": 1,
"type": "zbl",
"models": LinearAtomicModel.serialize(
"models": LinearEnergyAtomicModel.serialize(
[self.dp_model, self.zbl_model], self.type_map
),
"sw_rmin": self.sw_rmin,
@@ -354,14 +425,14 @@ def serialize(self) -> dict:
return dd

@classmethod
def deserialize(cls, data) -> "DPZBLLinearAtomicModel":
def deserialize(cls, data) -> "DPZBLLinearEnergyAtomicModel":
data = copy.deepcopy(data)
check_version_compatibility(data.pop("@version", 1), 1, 1)
sw_rmin = data.pop("sw_rmin")
sw_rmax = data.pop("sw_rmax")
smin_alpha = data.pop("smin_alpha")

[dp_model, zbl_model], type_map = LinearAtomicModel.deserialize(
[dp_model, zbl_model], type_map = LinearEnergyAtomicModel.deserialize(
data.pop("models")
)

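Finally, a hedged sketch of the combination step in `forward_atomic` above: per-model atomic energies and the switching weights from `_compute_weight` are stacked along a new model axis, multiplied, and summed away (shapes and values invented for illustration):

```python
import torch

nframes, nloc, nmodels = 2, 5, 2
# One (nframes, nloc, 1) atomic-energy tensor per sub-model (e.g. DP and ZBL),
# plus a matching switching-weight tensor per model.
ener_list = [torch.randn(nframes, nloc, 1, dtype=torch.float64) for _ in range(nmodels)]
weights = [torch.rand(nframes, nloc, 1, dtype=torch.float64) for _ in range(nmodels)]

# Same reduction as the diff: stack along a new leading model axis,
# weight each model's contribution, and sum that axis away.
energy = torch.sum(torch.stack(ener_list) * torch.stack(weights), dim=0)
print(energy.shape)  # torch.Size([2, 5, 1])
```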