Feat (pt): Expose Linear Ener Model (#4194)
## Summary by CodeRabbit

- **New Features**
  - Introduced two new JSON configuration files for linear energy calculations in water simulations.
  - Launched the `LinearEnergyModel` class for advanced energy and force calculations.
  - Added a parameter for customizable model weighting in the linear energy model.
  - Expanded test suite with new test classes for validating linear energy models.
  - Added new model configurations and test classes to enhance testing capabilities.

- **Bug Fixes**
  - Corrected input handling in the deserialization method for version compatibility.
  - Adjusted numerical values in data files for accurate testing.

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
anyangml and pre-commit-ci[bot] authored Oct 11, 2024
1 parent 3939786 commit 61f1681
Showing 12 changed files with 622 additions and 12 deletions.
57 changes: 49 additions & 8 deletions deepmd/pt/model/atomic_model/linear_atomic_model.py
@@ -48,12 +48,15 @@ class LinearEnergyAtomicModel(BaseAtomicModel):
type_map : list[str]
Mapping atom type to the name (str) of the type.
For example `type_map[1]` gives the name of the type 1.
weights : Optional[Union[str, list[float]]]
Weights of the models. If str, must be `sum` or `mean`; if list, it must contain one float weight per model.
"""

def __init__(
self,
models: list[BaseAtomicModel],
type_map: list[str],
weights: Optional[Union[str, list[float]]] = "mean",
**kwargs,
):
super().__init__(type_map, **kwargs)
@@ -89,6 +92,16 @@ def __init__(
)
self.nsels = torch.tensor(self.get_model_nsels(), device=env.DEVICE) # pylint: disable=no-explicit-dtype

if isinstance(weights, str):
assert weights in ["sum", "mean"]
elif isinstance(weights, list):
assert len(weights) == len(models)
else:
raise ValueError(
f"'weights' must be a string ('sum' or 'mean') or a list of float of length {len(models)}."
)
self.weights = weights

def mixed_types(self) -> bool:
"""If true, the model
1. assumes total number of atoms aligned across frames;
@@ -320,7 +333,7 @@ def serialize(self) -> dict:
@classmethod
def deserialize(cls, data: dict) -> "LinearEnergyAtomicModel":
data = copy.deepcopy(data)
-        check_version_compatibility(data.get("@version", 2), 2, 1)
+        check_version_compatibility(data.pop("@version", 2), 2, 1)
data.pop("@class", None)
data.pop("type", None)
models = [
@@ -331,16 +344,42 @@ def deserialize(cls, data: dict) -> "LinearEnergyAtomicModel":
return super().deserialize(data)

def _compute_weight(
-        self, extended_coord, extended_atype, nlists_
+        self,
+        extended_coord: torch.Tensor,
+        extended_atype: torch.Tensor,
+        nlists_: list[torch.Tensor],
) -> list[torch.Tensor]:
"""This should be a list of user defined weights that matches the number of models to be combined."""
nmodels = len(self.models)
nframes, nloc, _ = nlists_[0].shape
-        return [
-            torch.ones((nframes, nloc, 1), dtype=torch.float64, device=env.DEVICE)
-            / nmodels
-            for _ in range(nmodels)
-        ]
+        if isinstance(self.weights, str):
+            if self.weights == "sum":
+                return [
+                    torch.ones(
+                        (nframes, nloc, 1), dtype=torch.float64, device=env.DEVICE
+                    )
+                    for _ in range(nmodels)
+                ]
+            elif self.weights == "mean":
+                return [
+                    torch.ones(
+                        (nframes, nloc, 1), dtype=torch.float64, device=env.DEVICE
+                    )
+                    / nmodels
+                    for _ in range(nmodels)
+                ]
+            else:
+                raise ValueError(
+                    "`weights` must be 'sum' or 'mean' when provided as a string."
+                )
+        elif isinstance(self.weights, list):
+            return [
+                torch.ones((nframes, nloc, 1), dtype=torch.float64, device=env.DEVICE)
+                * w
+                for w in self.weights
+            ]
+        else:
+            raise NotImplementedError

def get_dim_fparam(self) -> int:
"""Get the number (dimension) of frame parameters of this atomic model."""
@@ -365,7 +404,9 @@ def get_sel_type(self) -> list[int]:
return torch.unique(
torch.cat(
[
-                    torch.as_tensor(model.get_sel_type(), dtype=torch.int32)
+                    torch.as_tensor(
+                        model.get_sel_type(), dtype=torch.int64, device=env.DEVICE
+                    )
for model in self.models
]
)
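The `weights` handling added to `_compute_weight` above boils down to a per-model scalar broadcast over `(nframes, nloc, 1)`. A minimal standalone sketch of that rule (plain `torch`, illustrative shapes; `compute_weights` is our name for the sketch, not a deepmd API):

```python
import torch


def compute_weights(weights, nmodels: int, nframes: int, nloc: int) -> list[torch.Tensor]:
    """Mirror the combination rule: 'sum' -> 1.0 per model, 'mean' -> 1/nmodels,
    a list -> one fixed scalar per model, each broadcast to (nframes, nloc, 1)."""
    shape = (nframes, nloc, 1)
    if weights == "sum":
        return [torch.ones(shape, dtype=torch.float64) for _ in range(nmodels)]
    if weights == "mean":
        return [torch.ones(shape, dtype=torch.float64) / nmodels for _ in range(nmodels)]
    if isinstance(weights, list):
        assert len(weights) == nmodels
        return [torch.ones(shape, dtype=torch.float64) * w for w in weights]
    raise ValueError("weights must be 'sum', 'mean', or a list of floats")


# Two models with "mean": each per-atom contribution is weighted by 0.5.
ws = compute_weights("mean", nmodels=2, nframes=1, nloc=3)
print(ws[0].flatten())  # tensor([0.5000, 0.5000, 0.5000], dtype=torch.float64)
```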
62 changes: 62 additions & 0 deletions deepmd/pt/model/model/__init__.py
@@ -36,6 +36,9 @@
from .dos_model import (
DOSModel,
)
from .dp_linear_model import (
LinearEnergyModel,
)
from .dp_model import (
DPModelCommon,
)
@@ -105,6 +108,62 @@ def get_spin_model(model_params):
return SpinEnergyModel(backbone_model=backbone_model, spin=spin)


def get_linear_model(model_params):
model_params = copy.deepcopy(model_params)
weights = model_params.get("weights", "mean")
list_of_models = []
ntypes = len(model_params["type_map"])
for sub_model_params in model_params["models"]:
if "descriptor" in sub_model_params:
# descriptor
sub_model_params["descriptor"]["ntypes"] = ntypes
sub_model_params["descriptor"]["type_map"] = copy.deepcopy(
model_params["type_map"]
)
descriptor = BaseDescriptor(**sub_model_params["descriptor"])
# fitting
fitting_net = sub_model_params.get("fitting_net", {})
fitting_net["type"] = fitting_net.get("type", "ener")
fitting_net["ntypes"] = descriptor.get_ntypes()
fitting_net["type_map"] = copy.deepcopy(model_params["type_map"])
fitting_net["mixed_types"] = descriptor.mixed_types()
if fitting_net["type"] in ["dipole", "polar"]:
fitting_net["embedding_width"] = descriptor.get_dim_emb()
fitting_net["dim_descrpt"] = descriptor.get_dim_out()
grad_force = "direct" not in fitting_net["type"]
if not grad_force:
fitting_net["out_dim"] = descriptor.get_dim_emb()
if "ener" in fitting_net["type"]:
fitting_net["return_energy"] = True
fitting = BaseFitting(**fitting_net)
list_of_models.append(
DPAtomicModel(descriptor, fitting, type_map=model_params["type_map"])
)

else: # must be pairtab
assert (
"type" in sub_model_params and sub_model_params["type"] == "pairtab"
), "Sub-models in LinearEnergyModel must be a DPModel or a PairTable Model"
list_of_models.append(
PairTabAtomicModel(
sub_model_params["tab_file"],
sub_model_params["rcut"],
sub_model_params["sel"],
type_map=model_params["type_map"],
)
)

atom_exclude_types = model_params.get("atom_exclude_types", [])
pair_exclude_types = model_params.get("pair_exclude_types", [])
return LinearEnergyModel(
models=list_of_models,
type_map=model_params["type_map"],
weights=weights,
atom_exclude_types=atom_exclude_types,
pair_exclude_types=pair_exclude_types,
)


def get_zbl_model(model_params):
model_params = copy.deepcopy(model_params)
ntypes = len(model_params["type_map"])
@@ -247,6 +306,8 @@ def get_model(model_params):
return get_zbl_model(model_params)
else:
return get_standard_model(model_params)
elif model_type == "linear_ener":
return get_linear_model(model_params)
else:
return BaseModel.get_class_by_type(model_type).get_model(model_params)

@@ -265,4 +326,5 @@ def get_model(model_params):
"DPZBLModel",
"make_model",
"make_hessian_model",
"LinearEnergyModel",
]
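With the `linear_ener` branch wired into `get_model`, a combined model can be requested from a single parameter dict. A hedged sketch of what such a dict might look like (the descriptor/fitting settings and the `dftd3.txt` path are illustrative placeholders, not values taken from this commit, and are not validated against the full config schema):

```python
from deepmd.pt.model.model import get_model

model_params = {
    "type": "linear_ener",
    "type_map": ["O", "H"],
    "weights": "mean",  # or "sum", or an explicit list such as [0.8, 0.2]
    "models": [
        {
            # a standard DP sub-model: descriptor + fitting net
            "descriptor": {
                "type": "se_e2_a",
                "sel": [46, 92],
                "rcut": 6.0,
                "rcut_smth": 0.5,
                "neuron": [25, 50, 100],
            },
            "fitting_net": {"neuron": [240, 240, 240]},
        },
        {
            # a tabulated pair potential sub-model
            "type": "pairtab",
            "tab_file": "dftd3.txt",
            "rcut": 10.0,
            "sel": 534,
        },
    ],
}

model = get_model(model_params)  # dispatches to get_linear_model -> LinearEnergyModel
```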
166 changes: 166 additions & 0 deletions deepmd/pt/model/model/dp_linear_model.py
@@ -0,0 +1,166 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from copy import (
deepcopy,
)
from typing import (
Optional,
)

import torch

from deepmd.pt.model.atomic_model import (
LinearEnergyAtomicModel,
)
from deepmd.pt.model.model.model import (
BaseModel,
)
from deepmd.utils.data_system import (
DeepmdDataSystem,
)

from .dp_model import (
DPModelCommon,
)
from .make_model import (
make_model,
)

DPLinearModel_ = make_model(LinearEnergyAtomicModel)


@BaseModel.register("linear_ener")
class LinearEnergyModel(DPLinearModel_):
model_type = "ener"

def __init__(
self,
*args,
**kwargs,
):
super().__init__(*args, **kwargs)

def translated_output_def(self):
out_def_data = self.model_output_def().get_data()
output_def = {
"atom_energy": deepcopy(out_def_data["energy"]),
"energy": deepcopy(out_def_data["energy_redu"]),
}
if self.do_grad_r("energy"):
output_def["force"] = deepcopy(out_def_data["energy_derv_r"])
output_def["force"].squeeze(-2)
if self.do_grad_c("energy"):
output_def["virial"] = deepcopy(out_def_data["energy_derv_c_redu"])
output_def["virial"].squeeze(-2)
output_def["atom_virial"] = deepcopy(out_def_data["energy_derv_c"])
output_def["atom_virial"].squeeze(-3)
if "mask" in out_def_data:
output_def["mask"] = deepcopy(out_def_data["mask"])
return output_def

def forward(
self,
coord,
atype,
box: Optional[torch.Tensor] = None,
fparam: Optional[torch.Tensor] = None,
aparam: Optional[torch.Tensor] = None,
do_atomic_virial: bool = False,
) -> dict[str, torch.Tensor]:
model_ret = self.forward_common(
coord,
atype,
box,
fparam=fparam,
aparam=aparam,
do_atomic_virial=do_atomic_virial,
)

model_predict = {}
model_predict["atom_energy"] = model_ret["energy"]
model_predict["energy"] = model_ret["energy_redu"]
if self.do_grad_r("energy"):
model_predict["force"] = model_ret["energy_derv_r"].squeeze(-2)
if self.do_grad_c("energy"):
model_predict["virial"] = model_ret["energy_derv_c_redu"].squeeze(-2)
if do_atomic_virial:
model_predict["atom_virial"] = model_ret["energy_derv_c"].squeeze(-3)
else:
model_predict["force"] = model_ret["dforce"]
if "mask" in model_ret:
model_predict["mask"] = model_ret["mask"]
return model_predict

@torch.jit.export
def forward_lower(
self,
extended_coord,
extended_atype,
nlist,
mapping: Optional[torch.Tensor] = None,
fparam: Optional[torch.Tensor] = None,
aparam: Optional[torch.Tensor] = None,
do_atomic_virial: bool = False,
):
model_ret = self.forward_common_lower(
extended_coord,
extended_atype,
nlist,
mapping=mapping,
fparam=fparam,
aparam=aparam,
do_atomic_virial=do_atomic_virial,
extra_nlist_sort=self.need_sorted_nlist_for_lower(),
)

model_predict = {}
model_predict["atom_energy"] = model_ret["energy"]
model_predict["energy"] = model_ret["energy_redu"]
if self.do_grad_r("energy"):
model_predict["extended_force"] = model_ret["energy_derv_r"].squeeze(-2)
if self.do_grad_c("energy"):
model_predict["virial"] = model_ret["energy_derv_c_redu"].squeeze(-2)
if do_atomic_virial:
model_predict["extended_virial"] = model_ret["energy_derv_c"].squeeze(
-3
)
else:
assert model_ret["dforce"] is not None
model_predict["dforce"] = model_ret["dforce"]
return model_predict

@classmethod
def update_sel(
cls,
train_data: DeepmdDataSystem,
type_map: Optional[list[str]],
local_jdata: dict,
) -> tuple[dict, Optional[float]]:
"""Update the selection and perform neighbor statistics.
Parameters
----------
train_data : DeepmdDataSystem
data used to do neighbor statistics
type_map : list[str], optional
The name of each type of atoms
local_jdata : dict
The local data refer to the current class
Returns
-------
dict
The updated local data
float
The minimum distance between two atoms
"""
local_jdata_cpy = local_jdata.copy()
type_map = local_jdata_cpy["type_map"]
min_nbor_dist = None
for idx, sub_model in enumerate(local_jdata_cpy["models"]):
if "tab_file" not in sub_model:
sub_model, temp_min = DPModelCommon.update_sel(
train_data, type_map, local_jdata["models"][idx]
)
if min_nbor_dist is None or temp_min <= min_nbor_dist:
min_nbor_dist = temp_min
return local_jdata_cpy, min_nbor_dist
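A usage sketch for the class above, assuming `model` is a `LinearEnergyModel` built as in the previous example; the tensor shapes and dtypes here are our assumptions about the `forward` signature, not verified conventions:

```python
import torch

nframes, natoms = 1, 3
coord = torch.rand(nframes, natoms, 3, dtype=torch.float64)  # Cartesian coordinates
atype = torch.tensor([[0, 1, 1]])  # one O and two H, indexing into type_map

# box is Optional and defaults to None (open boundaries)
out = model(coord, atype)

print(out["energy"])  # reduced energy per frame: the weighted sum over sub-models
print(out["force"])   # per-atom forces, i.e. the negative gradient of the energy
```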
4 changes: 2 additions & 2 deletions doc/model/linear.md
@@ -1,7 +1,7 @@
-## Linear model {{ tensorflow_icon }}
+## Linear model {{ tensorflow_icon }} {{ pytorch_icon }}

:::{note}
-**Supported backends**: TensorFlow {{ tensorflow_icon }}
+**Supported backends**: TensorFlow {{ tensorflow_icon }}, PyTorch {{ pytorch_icon }}
:::

One can linearly combine existing models with arbitrary coefficients:
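In our notation (the symbols are not from the docs, but the semantics follow the `weights` option introduced in this commit), the combined energy of $n$ sub-models is

$$
E = \sum_{i=1}^{n} w_i E_i,
\qquad
w_i =
\begin{cases}
1, & \texttt{weights} = \text{"sum"},\\
1/n, & \texttt{weights} = \text{"mean"},\\
\text{user-supplied}, & \texttt{weights} = [w_1, \dots, w_n],
\end{cases}
$$

where $E_i$ is the energy predicted by the $i$-th sub-model; forces and virials inherit the same linear combination, since differentiation is linear.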
2 changes: 1 addition & 1 deletion examples/water/d3/dftd3.txt
@@ -97,4 +97,4 @@
9.700000000000001066e+00 -1.186747936398473687e-05 -7.637113677130612127e-06 -5.528293849956352819e-06
9.800000000000000711e+00 -1.114523618469756001e-05 -7.174288601187318493e-06 -5.194401230658985063e-06
9.900000000000000355e+00 -1.047381249252528874e-05 -6.743886368019750717e-06 -4.883815978498405921e-06
-1.000000000000000000e+01 0.000000000000000e00e+00 0.000000000000000e00e+00 0.000000000000000e00e+00
+1.000000000000000000e+01 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00