Skip to content

Commit

Permalink
Feat: add zbl training (#3398)
Browse files Browse the repository at this point in the history
Signed-off-by: Anyang Peng <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
anyangml and pre-commit-ci[bot] authored Mar 6, 2024
1 parent 278e6b8 commit df20f4c
Show file tree
Hide file tree
Showing 17 changed files with 539 additions and 130 deletions.
8 changes: 4 additions & 4 deletions deepmd/dpmodel/atomic_model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@
DPAtomicModel,
)
from .linear_atomic_model import (
DPZBLLinearAtomicModel,
LinearAtomicModel,
DPZBLLinearEnergyAtomicModel,
LinearEnergyAtomicModel,
)
from .make_base_atomic_model import (
make_base_atomic_model,
Expand All @@ -37,6 +37,6 @@
"BaseAtomicModel",
"DPAtomicModel",
"PairTabAtomicModel",
"LinearAtomicModel",
"DPZBLLinearAtomicModel",
"LinearEnergyAtomicModel",
"DPZBLLinearEnergyAtomicModel",
]
67 changes: 51 additions & 16 deletions deepmd/dpmodel/atomic_model/linear_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
)


class LinearAtomicModel(BaseAtomicModel):
class LinearEnergyAtomicModel(BaseAtomicModel):
"""Linear model make linear combinations of several existing models.
Parameters
Expand Down Expand Up @@ -163,17 +163,21 @@ def forward_atomic(
self.mixed_types_list, raw_nlists, self.get_model_sels()
)
]
ener_list = [
model.forward_atomic(
extended_coord,
extended_atype,
nl,
mapping,
fparam,
aparam,
)["energy"]
for model, nl in zip(self.models, nlists_)
]
ener_list = []

for i, model in enumerate(self.models):
ori_map = model.get_type_map()
updated_atype = self.remap_atype(extended_atype, ori_map, self.type_map)
ener_list.append(
model.forward_atomic(
extended_coord,
updated_atype,
nlists_[i],
mapping,
fparam,
aparam,
)["energy"]
)
self.weights = self._compute_weight(extended_coord, extended_atype, nlists_)
self.atomic_bias = None
if self.atomic_bias is not None:
Expand All @@ -184,6 +188,37 @@ def forward_atomic(
} # (nframes, nloc, 1)
return fit_ret

@staticmethod
def remap_atype(
    atype: np.ndarray, ori_map: List[str], new_map: List[str]
) -> np.ndarray:
    """Remap atom types from the common type map to a sub-model's own type map.

    Each entry of `atype` is an index into `new_map`; it is replaced by the
    index of the same type name in `ori_map`.

    Parameters
    ----------
    atype : np.ndarray
        The integer atom type array being updated, shape of (nframes, natoms).
        Zero-size arrays are accepted and returned remapped (i.e. still empty).
    ori_map : List[str]
        The original type map of an individual AtomicModel.
    new_map : List[str]
        The common type map of the linear model. Every name in it must also
        be present in `ori_map`.

    Returns
    -------
    np.ndarray
        Array with the same shape as `atype`, holding indices into `ori_map`.

    Raises
    ------
    AssertionError
        If the largest value in `atype` is not a valid index into `new_map`.
    KeyError
        If a name in `new_map` does not appear in `ori_map`.
    """
    # np.max raises ValueError on a zero-size array, so the range check is
    # only meaningful (and only safe) when there is at least one atom.
    if atype.size > 0:
        assert np.max(atype) < len(
            new_map
        ), "The input `atype` cannot be handled by the type_map."
    type_2_idx = {atp: idx for idx, atp in enumerate(ori_map)}
    # Position i holds the index in `ori_map` of the i-th type of `new_map`;
    # the explicit integer dtype keeps the table integral even when `new_map` is empty.
    mapping = np.array([type_2_idx[atp] for atp in new_map], dtype=int)
    return mapping[atype]

def fitting_output_def(self) -> FittingOutputDef:
return FittingOutputDef(
[
Expand Down Expand Up @@ -261,7 +296,7 @@ def is_aparam_nall(self) -> bool:
return False


class DPZBLLinearAtomicModel(LinearAtomicModel):
class DPZBLLinearEnergyAtomicModel(LinearEnergyAtomicModel):
"""Model linearly combine a list of AtomicModels.
Parameters
Expand Down Expand Up @@ -308,7 +343,7 @@ def serialize(self) -> dict:
"@class": "Model",
"type": "zbl",
"@version": 1,
"models": LinearAtomicModel.serialize(
"models": LinearEnergyAtomicModel.serialize(
[self.dp_model, self.zbl_model], self.type_map
),
"sw_rmin": self.sw_rmin,
Expand All @@ -319,7 +354,7 @@ def serialize(self) -> dict:
return dd

@classmethod
def deserialize(cls, data) -> "DPZBLLinearAtomicModel":
def deserialize(cls, data) -> "DPZBLLinearEnergyAtomicModel":
data = copy.deepcopy(data)
check_version_compatibility(data.pop("@version", 1), 1, 1)
data.pop("@class")
Expand All @@ -328,7 +363,7 @@ def deserialize(cls, data) -> "DPZBLLinearAtomicModel":
sw_rmax = data.pop("sw_rmax")
smin_alpha = data.pop("smin_alpha")

([dp_model, zbl_model], type_map) = LinearAtomicModel.deserialize(
([dp_model, zbl_model], type_map) = LinearEnergyAtomicModel.deserialize(
data.pop("models")
)

Expand Down
11 changes: 3 additions & 8 deletions deepmd/dpmodel/atomic_model/make_base_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,18 +54,13 @@ def get_rcut(self) -> float:
pass

@abstractmethod
def get_type_map(self) -> Optional[List[str]]:
def get_type_map(self) -> List[str]:
"""Get the type map."""
pass

def get_ntypes(self) -> int:
"""Get the number of atom types."""
tmap = self.get_type_map()
if tmap is not None:
return len(tmap)
else:
raise ValueError(
"cannot infer the number of types from a None type map"
)
return len(self.get_type_map())

@abstractmethod
def get_sel(self) -> List[int]:
Expand Down
11 changes: 10 additions & 1 deletion deepmd/dpmodel/atomic_model/pairtab_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,17 @@ def __init__(
self.type_map = type_map

self.tab = PairTab(self.tab_file, rcut=rcut)
self.type_map = type_map
self.ntypes = len(type_map)

if self.tab_file is not None:
self.tab_info, self.tab_data = self.tab.get()
nspline, ntypes_tab = self.tab_info[-2:].astype(int)
self.tab_data = self.tab_data.reshape(ntypes_tab, ntypes_tab, nspline, 4)
if self.ntypes != ntypes_tab:
raise ValueError(
"The `type_map` provided does not match the number of columns in the table."
)
else:
self.tab_info, self.tab_data = None, None

Expand Down Expand Up @@ -145,7 +153,8 @@ def deserialize(cls, data) -> "PairTabAtomicModel":
tab_model = cls(None, rcut, sel, type_map, **data)
tab_model.tab = tab
tab_model.tab_info = tab_model.tab.tab_info
tab_model.tab_data = tab_model.tab.tab_data
nspline, ntypes = tab_model.tab_info[-2:].astype(int)
tab_model.tab_data = tab_model.tab.tab_data.reshape(ntypes, ntypes, nspline, 4)
return tab_model

def forward_atomic(
Expand Down
8 changes: 4 additions & 4 deletions deepmd/pt/model/atomic_model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@
DPAtomicModel,
)
from .linear_atomic_model import (
DPZBLLinearAtomicModel,
LinearAtomicModel,
DPZBLLinearEnergyAtomicModel,
LinearEnergyAtomicModel,
)
from .pairtab_atomic_model import (
PairTabAtomicModel,
Expand All @@ -32,6 +32,6 @@
"BaseAtomicModel",
"DPAtomicModel",
"PairTabAtomicModel",
"LinearAtomicModel",
"DPZBLLinearAtomicModel",
"LinearEnergyAtomicModel",
"DPZBLLinearEnergyAtomicModel",
]
105 changes: 89 additions & 16 deletions deepmd/pt/model/atomic_model/linear_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@
get_multiple_nlist_key,
nlist_distinguish_types,
)
from deepmd.utils.path import (
DPPath,
)
from deepmd.utils.version import (
check_version_compatibility,
)
Expand All @@ -40,7 +43,7 @@
)


class LinearAtomicModel(torch.nn.Module, BaseAtomicModel):
class LinearEnergyAtomicModel(torch.nn.Module, BaseAtomicModel):
"""Linear model make linear combinations of several existing models.
Parameters
Expand Down Expand Up @@ -117,8 +120,8 @@ def _sort_rcuts_sels(self, device: torch.device) -> Tuple[List[float], List[int]
nsels = torch.tensor(self.get_model_nsels(), device=device)
zipped = torch.stack(
[
torch.tensor(rcuts, device=device),
torch.tensor(nsels, device=device),
rcuts,
nsels,
],
dim=0,
).T
Expand Down Expand Up @@ -185,10 +188,12 @@ def forward_atomic(
ener_list = []

for i, model in enumerate(self.models):
ori_map = model.get_type_map()
updated_atype = self.remap_atype(extended_atype, ori_map, self.type_map)
ener_list.append(
model.forward_atomic(
extended_coord,
extended_atype,
updated_atype,
nlists_[i],
mapping,
fparam,
Expand All @@ -198,16 +203,57 @@ def forward_atomic(

weights = self._compute_weight(extended_coord, extended_atype, nlists_)

if self.atomic_bias is not None:
raise NotImplementedError("Need to add bias in a future PR.")
else:
fit_ret = {
"energy": torch.sum(
torch.stack(ener_list) * torch.stack(weights), dim=0
),
} # (nframes, nloc, 1)
atype = extended_atype[:, :nloc]
for idx, model in enumerate(self.models):
# TODO: provide interfaces for atomic models to access bias_atom_e
if isinstance(model, DPAtomicModel):
bias_atom_e = model.fitting_net.bias_atom_e
elif isinstance(model, PairTabAtomicModel):
bias_atom_e = model.bias_atom_e
else:
bias_atom_e = None
if bias_atom_e is not None:
ener_list[idx] += bias_atom_e[atype]

fit_ret = {
"energy": torch.sum(torch.stack(ener_list) * torch.stack(weights), dim=0),
} # (nframes, nloc, 1)
return fit_ret

@staticmethod
def remap_atype(
    atype: torch.Tensor, ori_map: List[str], new_map: List[str]
) -> torch.Tensor:
    """Remap atom types from the common type map to a sub-model's own type map.

    Each entry of `atype` is an index into `new_map`; it is replaced by the
    index of the same type name in `ori_map`.

    Parameters
    ----------
    atype : torch.Tensor
        The integer atom type tensor being updated, shape of (nframes, natoms).
        Empty tensors are accepted and returned remapped (i.e. still empty).
    ori_map : List[str]
        The original type map of an individual AtomicModel.
    new_map : List[str]
        The common type map of the linear model. Every name in it must also
        be present in `ori_map`.

    Returns
    -------
    torch.Tensor
        Long tensor with the same shape as `atype`, on the same device,
        holding indices into `ori_map`.

    Raises
    ------
    AssertionError
        If the largest value in `atype` is not a valid index into `new_map`.
    KeyError
        If a name in `new_map` does not appear in `ori_map`.
    """
    # torch.max raises RuntimeError on an empty tensor, so the range check is
    # only meaningful (and only safe) when there is at least one atom.
    if atype.numel() > 0:
        assert torch.max(atype) < len(
            new_map
        ), "The input `atype` cannot be handled by the type_map."
    type_2_idx = {atp: idx for idx, atp in enumerate(ori_map)}
    # Position i holds the index in `ori_map` of the i-th type of `new_map`;
    # the explicit long dtype keeps the table integral even when `new_map` is empty.
    mapping = torch.tensor(
        [type_2_idx[atp] for atp in new_map],
        dtype=torch.long,
        device=atype.device,
    )
    return mapping[atype.long()]

def fitting_output_def(self) -> FittingOutputDef:
return FittingOutputDef(
[
Expand Down Expand Up @@ -292,7 +338,7 @@ def is_aparam_nall(self) -> bool:
return False


class DPZBLLinearAtomicModel(LinearAtomicModel):
class DPZBLLinearEnergyAtomicModel(LinearEnergyAtomicModel):
"""Model linearly combine a list of AtomicModels.
Parameters
Expand Down Expand Up @@ -336,14 +382,41 @@ def __init__(
# this is a placeholder being updated in _compute_weight, to handle Jit attribute init error.
self.zbl_weight = torch.empty(0, dtype=torch.float64, device=env.DEVICE)

def compute_or_load_stat(
    self,
    sampled_func,
    stat_file_path: Optional[DPPath] = None,
):
    """Compute or load the statistics of both wrapped sub-models.

    The call is forwarded unchanged, first to `self.dp_model` and then to
    `self.zbl_model`. What is actually computed or loaded (e.g. descriptor
    mean / standard deviation, energy bias of the fitting net) and whether
    results are written to or read from `stat_file_path` is decided by each
    sub-model's own `compute_or_load_stat`.

    Parameters
    ----------
    sampled_func
        The lazy sampled function to get data frames from different data systems.
    stat_file_path
        Location of the statistics files, handed to each sub-model as-is;
        may be None.
    """
    # Attribute lookup happens right before each call, so the DP model is
    # fully processed before the ZBL model is touched.
    for sub_model_name in ("dp_model", "zbl_model"):
        getattr(self, sub_model_name).compute_or_load_stat(
            sampled_func, stat_file_path
        )

def change_energy_bias(self):
    """Placeholder for changing the energy bias; currently does nothing."""
    # TODO: need to implement
    return None

def serialize(self) -> dict:
dd = BaseAtomicModel.serialize(self)
dd.update(
{
"@class": "Model",
"@version": 1,
"type": "zbl",
"models": LinearAtomicModel.serialize(
"models": LinearEnergyAtomicModel.serialize(
[self.dp_model, self.zbl_model], self.type_map
),
"sw_rmin": self.sw_rmin,
Expand All @@ -354,14 +427,14 @@ def serialize(self) -> dict:
return dd

@classmethod
def deserialize(cls, data) -> "DPZBLLinearAtomicModel":
def deserialize(cls, data) -> "DPZBLLinearEnergyAtomicModel":
data = copy.deepcopy(data)
check_version_compatibility(data.pop("@version", 1), 1, 1)
sw_rmin = data.pop("sw_rmin")
sw_rmax = data.pop("sw_rmax")
smin_alpha = data.pop("smin_alpha")

[dp_model, zbl_model], type_map = LinearAtomicModel.deserialize(
[dp_model, zbl_model], type_map = LinearEnergyAtomicModel.deserialize(
data.pop("models")
)

Expand Down
Loading

0 comments on commit df20f4c

Please sign in to comment.