Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat (pt): Expose Linear Ener Model #4194

Merged
merged 31 commits into from
Oct 11, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
e906fac
fix: zbl mix type model
anyangml Oct 8, 2024
91cf861
feat: add linear model
anyangml Oct 8, 2024
7d3044c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 8, 2024
082ab74
fix: dftd3 example
anyangml Oct 8, 2024
0104e18
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 8, 2024
21580e1
feat: add pt example
anyangml Oct 8, 2024
08fcb55
fix: jit
anyangml Oct 8, 2024
739670e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 8, 2024
0522562
feat: add UTs
anyangml Oct 8, 2024
8b1cb8c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 8, 2024
f2753e7
fix: UTs
anyangml Oct 8, 2024
63e7017
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 8, 2024
0514eae
fix: UTs
anyangml Oct 8, 2024
f41df5b
fix: sel type UT
anyangml Oct 8, 2024
000c1c8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 8, 2024
189961c
fix: UT sel type dtype to long
anyangml Oct 8, 2024
9715641
fix: revert dtype change
anyangml Oct 8, 2024
c8e86fe
fix: revert ut change
anyangml Oct 8, 2024
a935784
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 8, 2024
169203b
Merge branch 'devel' into feat/expose-linear-model
anyangml Oct 9, 2024
b664e55
fix: rename, fix UT device
anyangml Oct 9, 2024
8f06bb5
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 9, 2024
af16e65
change get_sel_type dtype to int64
anyangml Oct 9, 2024
34e3c97
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 9, 2024
34843a6
feat: add test training
anyangml Oct 9, 2024
1579a7e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 9, 2024
11350e2
fix: revert changes
anyangml Oct 9, 2024
5b5e948
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 9, 2024
576c289
fix: update zbl example descriptor
anyangml Oct 10, 2024
5e34b9c
Merge branch 'devel' into feat/expose-linear-model
anyangml Oct 10, 2024
d3b3342
feat: add linear example
anyangml Oct 10, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 35 additions & 5 deletions deepmd/pt/model/atomic_model/linear_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,15 @@ class LinearEnergyAtomicModel(BaseAtomicModel):
type_map : list[str]
Mapping atom type to the name (str) of the type.
For example `type_map[1]` gives the name of the type 1.
weights : Optional[Union[str,list[float]]]
Weights of the models. If str, must be `sum` or `mean`. If list, must be a list of float.
"""

def __init__(
self,
models: list[BaseAtomicModel],
type_map: list[str],
weights: Optional[Union[str, list[float]]] = "mean",
**kwargs,
):
super().__init__(type_map, **kwargs)
Expand Down Expand Up @@ -89,6 +92,16 @@ def __init__(
)
self.nsels = torch.tensor(self.get_model_nsels(), device=env.DEVICE) # pylint: disable=no-explicit-dtype

if isinstance(weights, str):
assert weights in ["sum", "mean"]
elif isinstance(weights, list):
assert len(weights) == len(models)
else:
raise ValueError(
f"'weights' must be a string ('sum' or 'mean') or a list of float of length {len(models)}."
)
anyangml marked this conversation as resolved.
Show resolved Hide resolved
self.weights = weights

anyangml marked this conversation as resolved.
Show resolved Hide resolved
def mixed_types(self) -> bool:
"""If true, the model
1. assumes total number of atoms aligned across frames;
Expand Down Expand Up @@ -336,11 +349,28 @@ def _compute_weight(
"""This should be a list of user defined weights that matches the number of models to be combined."""
nmodels = len(self.models)
nframes, nloc, _ = nlists_[0].shape
return [
torch.ones((nframes, nloc, 1), dtype=torch.float64, device=env.DEVICE)
/ nmodels
for _ in range(nmodels)
]
if isinstance(self.weights, str):
if self.weights == "sum":
return [
torch.ones(
(nframes, nloc, 1), dtype=torch.float64, device=env.DEVICE
)
for _ in range(nmodels)
]
elif self.weights == "mean":
return [
torch.ones(
(nframes, nloc, 1), dtype=torch.float64, device=env.DEVICE
)
/ nmodels
for _ in range(nmodels)
]
elif isinstance(self.weights, list):
return [
torch.ones((nframes, nloc, 1), dtype=torch.float64, device=env.DEVICE)
* w
for w in self.weights
]

def get_dim_fparam(self) -> int:
"""Get the number (dimension) of frame parameters of this atomic model."""
Expand Down
61 changes: 61 additions & 0 deletions deepmd/pt/model/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@
from .dos_model import (
DOSModel,
)
from .dp_linear_model import (
DPLinearModel,
)
from .dp_model import (
DPModelCommon,
)
Expand Down Expand Up @@ -105,6 +108,62 @@ def get_spin_model(model_params):
return SpinEnergyModel(backbone_model=backbone_model, spin=spin)


def get_linear_model(model_params):
model_params = copy.deepcopy(model_params)
weights = model_params.get("weights", "mean")
list_of_models = []
ntypes = len(model_params["type_map"])
for sub_model_params in model_params["models"]:
anyangml marked this conversation as resolved.
Show resolved Hide resolved
if "descriptor" in sub_model_params:
# descriptor
sub_model_params["descriptor"]["ntypes"] = ntypes
sub_model_params["descriptor"]["type_map"] = copy.deepcopy(
model_params["type_map"]
)
descriptor = BaseDescriptor(**sub_model_params["descriptor"])
# fitting
fitting_net = sub_model_params.get("fitting_net", {})
fitting_net["type"] = fitting_net.get("type", "ener")
fitting_net["ntypes"] = descriptor.get_ntypes()
fitting_net["type_map"] = copy.deepcopy(sub_model_params["type_map"])
fitting_net["mixed_types"] = descriptor.mixed_types()
if fitting_net["type"] in ["dipole", "polar"]:
fitting_net["embedding_width"] = descriptor.get_dim_emb()
fitting_net["dim_descrpt"] = descriptor.get_dim_out()
grad_force = "direct" not in fitting_net["type"]
if not grad_force:
fitting_net["out_dim"] = descriptor.get_dim_emb()
if "ener" in fitting_net["type"]:
fitting_net["return_energy"] = True
fitting = BaseFitting(**fitting_net)
list_of_models.append(
DPAtomicModel(descriptor, fitting, type_map=model_params["type_map"])
)

else: # must be pairtab
assert (
"type" in sub_model_params and sub_model_params["type"] == "pairtab"
), "Sub-models in LinearEnergyModel must be a DPModel or a PairTable Model"
list_of_models.append(
PairTabAtomicModel(
sub_model_params["tab_file"],
sub_model_params["rcut"],
sub_model_params["sel"],
type_map=model_params["type_map"],
)
)
anyangml marked this conversation as resolved.
Show resolved Hide resolved

atom_exclude_types = model_params.get("atom_exclude_types", [])
pair_exclude_types = model_params.get("pair_exclude_types", [])
return DPLinearModel(
models=list_of_models,
type_map=model_params["type_map"],
weights=weights,
atom_exclude_types=atom_exclude_types,
pair_exclude_types=pair_exclude_types,
)

anyangml marked this conversation as resolved.
Show resolved Hide resolved

def get_zbl_model(model_params):
model_params = copy.deepcopy(model_params)
ntypes = len(model_params["type_map"])
Expand Down Expand Up @@ -247,6 +306,8 @@ def get_model(model_params):
return get_zbl_model(model_params)
else:
return get_standard_model(model_params)
elif model_type == "linear_ener":
return get_linear_model(model_params)
anyangml marked this conversation as resolved.
Show resolved Hide resolved
else:
return BaseModel.get_class_by_type(model_type).get_model(model_params)

Expand Down
160 changes: 160 additions & 0 deletions deepmd/pt/model/model/dp_linear_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from copy import (
deepcopy,
)
from typing import (
Optional,
)

import torch

from deepmd.pt.model.atomic_model import (
LinearEnergyAtomicModel,
)
from deepmd.pt.model.model.model import (
BaseModel,
)
from deepmd.utils.data_system import (
DeepmdDataSystem,
)

from .dp_model import (
DPModelCommon,
)
from .make_model import (
make_model,
)

DPLinearModel_ = make_model(LinearEnergyAtomicModel)
anyangml marked this conversation as resolved.
Show resolved Hide resolved


@BaseModel.register("linear_ener")
class DPLinearModel(DPLinearModel_):
model_type = "ener"

def __init__(
self,
*args,
**kwargs,
):
super().__init__(*args, **kwargs)

def translated_output_def(self):
out_def_data = self.model_output_def().get_data()
output_def = {
"atom_energy": deepcopy(out_def_data["energy"]),
"energy": deepcopy(out_def_data["energy_redu"]),
}
if self.do_grad_r("energy"):
output_def["force"] = deepcopy(out_def_data["energy_derv_r"])
output_def["force"].squeeze(-2)
if self.do_grad_c("energy"):
output_def["virial"] = deepcopy(out_def_data["energy_derv_c_redu"])
output_def["virial"].squeeze(-2)
output_def["atom_virial"] = deepcopy(out_def_data["energy_derv_c"])
output_def["atom_virial"].squeeze(-3)
anyangml marked this conversation as resolved.
Show resolved Hide resolved
if "mask" in out_def_data:
output_def["mask"] = deepcopy(out_def_data["mask"])
return output_def

def forward(
self,
coord,
atype,
box: Optional[torch.Tensor] = None,
fparam: Optional[torch.Tensor] = None,
aparam: Optional[torch.Tensor] = None,
do_atomic_virial: bool = False,
) -> dict[str, torch.Tensor]:
model_ret = self.forward_common(
coord,
atype,
box,
fparam=fparam,
aparam=aparam,
do_atomic_virial=do_atomic_virial,
)

model_predict = {}
model_predict["atom_energy"] = model_ret["energy"]
model_predict["energy"] = model_ret["energy_redu"]
if self.do_grad_r("energy"):
model_predict["force"] = model_ret["energy_derv_r"].squeeze(-2)
if self.do_grad_c("energy"):
model_predict["virial"] = model_ret["energy_derv_c_redu"].squeeze(-2)
if do_atomic_virial:
model_predict["atom_virial"] = model_ret["energy_derv_c"].squeeze(-3)
else:
model_predict["force"] = model_ret["dforce"]
anyangml marked this conversation as resolved.
Show resolved Hide resolved
if "mask" in model_ret:
model_predict["mask"] = model_ret["mask"]
return model_predict

@torch.jit.export
def forward_lower(
self,
extended_coord,
extended_atype,
nlist,
mapping: Optional[torch.Tensor] = None,
fparam: Optional[torch.Tensor] = None,
aparam: Optional[torch.Tensor] = None,
do_atomic_virial: bool = False,
):
model_ret = self.forward_common_lower(
extended_coord,
extended_atype,
nlist,
mapping=mapping,
fparam=fparam,
aparam=aparam,
do_atomic_virial=do_atomic_virial,
extra_nlist_sort=self.need_sorted_nlist_for_lower(),
)

model_predict = {}
model_predict["atom_energy"] = model_ret["energy"]
model_predict["energy"] = model_ret["energy_redu"]
if self.do_grad_r("energy"):
model_predict["extended_force"] = model_ret["energy_derv_r"].squeeze(-2)
if self.do_grad_c("energy"):
model_predict["virial"] = model_ret["energy_derv_c_redu"].squeeze(-2)
if do_atomic_virial:
model_predict["extended_virial"] = model_ret["energy_derv_c"].squeeze(
-3
)
else:
assert model_ret["dforce"] is not None
model_predict["dforce"] = model_ret["dforce"]
anyangml marked this conversation as resolved.
Show resolved Hide resolved
return model_predict
anyangml marked this conversation as resolved.
Show resolved Hide resolved

@classmethod
def update_sel(
cls,
train_data: DeepmdDataSystem,
type_map: Optional[list[str]],
local_jdata: dict,
) -> tuple[dict, Optional[float]]:
"""Update the selection and perform neighbor statistics.

Parameters
----------
train_data : DeepmdDataSystem
data used to do neighbor statictics
type_map : list[str], optional
The name of each type of atoms
local_jdata : dict
The local data refer to the current class

Returns
-------
dict
The updated local data
float
The minimum distance between two atoms
"""
local_jdata_cpy = local_jdata.copy()
local_jdata_cpy["dpmodel"], min_nbor_dist = DPModelCommon.update_sel(
train_data, type_map, local_jdata["dpmodel"]
)
anyangml marked this conversation as resolved.
Show resolved Hide resolved
return local_jdata_cpy, min_nbor_dist
2 changes: 1 addition & 1 deletion examples/water/zbl/input.json
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
"H"
],
"descriptor": {
"type": "se_e2_a",
"type": "se_atten",
anyangml marked this conversation as resolved.
Show resolved Hide resolved
anyangml marked this conversation as resolved.
Show resolved Hide resolved
"sel": [
46,
92
Expand Down
Loading