feat: add dp

anyangml committed Mar 12, 2024
1 parent 1f8c74c commit f9b0b06
Showing 6 changed files with 466 additions and 13 deletions.
81 changes: 81 additions & 0 deletions deepmd/dpmodel/fitting/dos_fitting.py
@@ -0,0 +1,81 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import copy
from typing import (
    TYPE_CHECKING,
    Any,
    List,
    Optional,
)

from deepmd.dpmodel.common import (
    DEFAULT_PRECISION,
)
from deepmd.dpmodel.fitting.invar_fitting import (
    InvarFitting,
)

if TYPE_CHECKING:
    from deepmd.dpmodel.fitting.general_fitting import (
        GeneralFitting,
    )
from deepmd.utils.version import (
    check_version_compatibility,
)


@InvarFitting.register("dos")
class DOSFittingNet(InvarFitting):
    def __init__(
        self,
        ntypes: int,
        dim_descrpt: int,
        neuron: List[int] = [120, 120, 120],
        resnet_dt: bool = True,
        numb_fparam: int = 0,
        numb_aparam: int = 0,
        numb_dos: int = 300,
        rcond: Optional[float] = None,
        trainable: Optional[List[bool]] = None,
        activation_function: str = "tanh",
        precision: str = DEFAULT_PRECISION,
        mixed_types: bool = False,
        exclude_types: List[int] = [],
        # not used
        seed: Optional[int] = None,
    ):
        super().__init__(
            var_name="dos",
            ntypes=ntypes,
            dim_descrpt=dim_descrpt,
            dim_out=numb_dos,
            neuron=neuron,
            resnet_dt=resnet_dt,
            numb_fparam=numb_fparam,
            numb_aparam=numb_aparam,
            rcond=rcond,
            trainable=trainable,
            activation_function=activation_function,
            precision=precision,
            mixed_types=mixed_types,
            exclude_types=exclude_types,
        )

    @classmethod
    def deserialize(cls, data: dict) -> "GeneralFitting":
        data = copy.deepcopy(data)
        check_version_compatibility(data.pop("@version", 1), 1, 1)
        data.pop("var_name")
        data.pop("dim_out")
        data.pop("tot_ener_zero")
        data.pop("layer_name")
        data.pop("use_aparam_as_mask")
        data.pop("spin")
        data.pop("atom_ener")
        return super().deserialize(data)

    def serialize(self) -> dict:
        """Serialize the fitting to dict."""
        return {
            **super().serialize(),
            "type": "dos",
        }
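For orientation, the registry decorator above means any serialized dict tagged type="dos" routes back to this class. A minimal round-trip sketch, assuming a deepmd-kit install that provides deepmd.dpmodel; the sizes below are hypothetical:

from deepmd.dpmodel.fitting.dos_fitting import DOSFittingNet

# Hypothetical sizes: 2 atom types, 64-dim descriptor, 250 DOS grid points.
fitting = DOSFittingNet(ntypes=2, dim_descrpt=64, numb_dos=250)

# serialize() tags the payload with type="dos"; deserialize() pops the
# bookkeeping fields (var_name, dim_out, ...) before rebuilding.
data = fitting.serialize()
restored = DOSFittingNet.deserialize(data)
assert restored.serialize()["type"] == "dos"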
80 changes: 80 additions & 0 deletions deepmd/pt/model/model/dos_model.py
@@ -0,0 +1,80 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
    Dict,
    Optional,
)

import torch

from .dp_model import (
    DPModel,
)


class DOSModel(DPModel):
    model_type = "dos"

    def __init__(
        self,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)

    def forward(
        self,
        coord,
        atype,
        box: Optional[torch.Tensor] = None,
        fparam: Optional[torch.Tensor] = None,
        aparam: Optional[torch.Tensor] = None,
        do_atomic_virial: bool = False,
    ) -> Dict[str, torch.Tensor]:
        model_ret = self.forward_common(
            coord,
            atype,
            box,
            fparam=fparam,
            aparam=aparam,
            do_atomic_virial=do_atomic_virial,
        )
        if self.get_fitting_net() is not None:
            model_predict = {}
            model_predict["atom_dos"] = model_ret["dos"]
            model_predict["dos"] = model_ret["dos_redu"]

            if "mask" in model_ret:
                model_predict["mask"] = model_ret["mask"]
        else:
            model_predict = model_ret
            model_predict["updated_coord"] += coord
        return model_predict

    @torch.jit.export
    def forward_lower(
        self,
        extended_coord,
        extended_atype,
        nlist,
        mapping: Optional[torch.Tensor] = None,
        fparam: Optional[torch.Tensor] = None,
        aparam: Optional[torch.Tensor] = None,
        do_atomic_virial: bool = False,
    ):
        model_ret = self.forward_common_lower(
            extended_coord,
            extended_atype,
            nlist,
            mapping,
            fparam=fparam,
            aparam=aparam,
            do_atomic_virial=do_atomic_virial,
        )
        if self.get_fitting_net() is not None:
            model_predict = {}
            model_predict["atom_dos"] = model_ret["dos"]
            model_predict["dos"] = model_ret["dos_redu"]
        else:
            model_predict = model_ret
        return model_predict
4 changes: 3 additions & 1 deletion deepmd/pt/model/model/dp_model.py
@@ -54,7 +54,9 @@ def __new__(
        from deepmd.pt.model.model.polar_model import (
            PolarModel,
        )
        from deepmd.pt.model.model.dos_model import (
            DOSModel,
        )

        if atomic_model_ is not None:
            fitting = atomic_model_.fitting_net
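The surrounding __new__ (collapsed here) inspects the atomic model's fitting net and returns the matching concrete model class; the import above is what lets a DOS fitting route to DOSModel. A standalone sketch of that dispatch pattern, with illustrative stand-in classes rather than the real ones:

# Illustrative stand-ins; the real dispatch lives in DPModel.__new__.
class EnerFit: ...
class DOSFit: ...

class EnerModelStub: ...
class DOSModelStub: ...

def pick_model_class(fitting) -> type:
    # Route on the fitting net's concrete type, as __new__ does.
    if isinstance(fitting, DOSFit):
        return DOSModelStub
    if isinstance(fitting, EnerFit):
        return EnerModelStub
    raise TypeError(f"unknown fitting type: {type(fitting).__name__}")

assert pick_model_class(DOSFit()) is DOSModelStub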
6 changes: 5 additions & 1 deletion deepmd/pt/model/task/dos.py
@@ -9,6 +9,9 @@
import torch

from deepmd.pt.model.task.fitting import (
    Fitting,
)
from deepmd.pt.model.task.ener import (
    InvarFitting,
)
from deepmd.pt.utils import (
@@ -64,11 +67,12 @@ def __init__(
            rcond=rcond,
            seed=seed,
            exclude_types=exclude_types,
            trainable=trainable,
            **kwargs,
        )

    @classmethod
    def deserialize(cls, data: dict) -> "DOSFittingNet":
    def deserialize(cls, data: dict) -> "InvarFitting":
        data = copy.deepcopy(data)
        check_version_compatibility(data.pop("@version", 1), 1, 1)
        data.pop("var_name")
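Every deserialize in this commit starts with check_version_compatibility(data.pop("@version", 1), 1, 1). A sketch of that guard's assumed contract (accept serialized format versions within [minimal, maximal], refuse everything else); the helper below is a stand-in, not the real implementation:

def version_guard(actual: int, maximal: int, minimal: int = 1) -> None:
    # Assumed semantics of deepmd.utils.version.check_version_compatibility:
    # payloads older than `minimal` or newer than `maximal` are refused.
    if not minimal <= actual <= maximal:
        raise ValueError(
            f"incompatible serialized version {actual}; "
            f"supported range is [{minimal}, {maximal}]"
        )

version_guard(1, 1, 1)  # the only version these fittings accept today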
108 changes: 97 additions & 11 deletions deepmd/tf/fit/dos.py
@@ -46,6 +46,9 @@
from deepmd.utils.out_stat import (
    compute_stats_from_redu,
)
from deepmd.utils.version import (
    check_version_compatibility,
)

log = logging.getLogger(__name__)

@@ -57,8 +60,10 @@ class DOSFitting(Fitting):
    Parameters
    ----------
    descrpt
        The descriptor :math:`\mathcal{D}`
    ntypes
        The number of atom types of the descriptor :math:`\mathcal{D}`
    dim_descrpt
        The dimension of the descriptor :math:`\mathcal{D}`
    neuron
        Number of neurons :math:`N` in each hidden layer of the fitting net
    resnet_dt
@@ -94,7 +99,8 @@ class DOSFitting(Fitting):
    def __init__(
        self,
        descrpt: tf.Tensor,
        ntypes: int,
        dim_descrpt: int,
        neuron: List[int] = [120, 120, 120],
        resnet_dt: bool = True,
        numb_fparam: int = 0,
@@ -112,8 +118,8 @@ def __init__(
    ) -> None:
        """Constructor."""
        # model param
        self.ntypes = descrpt.get_ntypes()
        self.dim_descrpt = descrpt.get_dim_out()
        self.ntypes = ntypes
        self.dim_descrpt = dim_descrpt
        self.use_aparam_as_mask = use_aparam_as_mask

        self.numb_fparam = numb_fparam
@@ -127,6 +133,7 @@ def __init__(
        self.seed = seed
        self.uniform_seed = uniform_seed
        self.seed_shift = one_layer_rand_seed_shift()
        self.activation_function = activation_function
        self.fitting_activation_fn = get_activation_func(activation_function)
        self.fitting_precision = get_precision(precision)
        self.trainable = trainable
@@ -145,16 +152,16 @@ def __init__(
            add_data_requirement(
                "fparam", self.numb_fparam, atomic=False, must=True, high_prec=False
            )
            self.fparam_avg = None
            self.fparam_std = None
            self.fparam_inv_std = None
        self.fparam_avg = None
        self.fparam_std = None
        self.fparam_inv_std = None
        if self.numb_aparam > 0:
            add_data_requirement(
                "aparam", self.numb_aparam, atomic=True, must=True, high_prec=False
            )
            self.aparam_avg = None
            self.aparam_std = None
            self.aparam_inv_std = None
        self.aparam_avg = None
        self.aparam_std = None
        self.aparam_inv_std = None

        self.fitting_net_variables = None
        self.mixed_prec = None
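The fparam_avg/fparam_inv_std (and aparam_*) attributes initialized here later hold standardization statistics, which also travel through "@variables" in the serialize/deserialize pair below. A minimal sketch of the assumed convention, inv_std = 1/std applied as (x - avg) * inv_std:

import numpy as np

def standardize(x: np.ndarray, avg: np.ndarray, inv_std: np.ndarray) -> np.ndarray:
    # Zero-mean, unit-variance inputs for the fitting net.
    return (x - avg) * inv_std

fp = np.array([[2.0, 4.0], [4.0, 8.0]])  # hypothetical: 2 frames x 2 fparams
avg = fp.mean(axis=0)
inv_std = 1.0 / fp.std(axis=0)
print(standardize(fp, avg, inv_std))  # each column now has mean 0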
@@ -641,3 +648,82 @@ def get_loss(self, loss: dict, lr) -> Loss:
        return DOSLoss(
            **loss, starter_learning_rate=lr.start_lr(), numb_dos=self.get_numb_dos()
        )

    @classmethod
    def deserialize(cls, data: dict, suffix: str = ""):
        """Deserialize the model.

        Parameters
        ----------
        data : dict
            The serialized data

        Returns
        -------
        Model
            The deserialized model
        """
        data = data.copy()
        check_version_compatibility(data.pop("@version", 1), 1, 1)
        fitting = cls(**data)
        fitting.fitting_net_variables = cls.deserialize_network(
            data["nets"],
            suffix=suffix,
        )
        fitting.bias_dos = data["@variables"]["bias_dos"]
        if fitting.numb_fparam > 0:
            fitting.fparam_avg = data["@variables"]["fparam_avg"]
            fitting.fparam_inv_std = data["@variables"]["fparam_inv_std"]
        if fitting.numb_aparam > 0:
            fitting.aparam_avg = data["@variables"]["aparam_avg"]
            fitting.aparam_inv_std = data["@variables"]["aparam_inv_std"]
        return fitting

    def serialize(self, suffix: str = "") -> dict:
        """Serialize the model.

        Returns
        -------
        dict
            The serialized data
        """
        data = {
            "@class": "Fitting",
            "type": "dos",
            "@version": 1,
            "var_name": "dos",
            "ntypes": self.ntypes,
            "dim_descrpt": self.dim_descrpt,
            # very bad design: type embedding is not passed to the class
            # TODO: refactor the class
            "mixed_types": False,
            "dim_out": 1,
            "neuron": self.n_neuron,
            "resnet_dt": self.resnet_dt,
            "numb_fparam": self.numb_fparam,
            "numb_aparam": self.numb_aparam,
            "rcond": self.rcond,
            "trainable": self.trainable,
            "activation_function": self.activation_function,
            "precision": self.fitting_precision.name,
            "exclude_types": [],
            "nets": self.serialize_network(
                ntypes=self.ntypes,
                # TODO: consider type embeddings
                ndim=1,
                in_dim=self.dim_descrpt + self.numb_fparam + self.numb_aparam,
                neuron=self.n_neuron,
                activation_function=self.activation_function,
                resnet_dt=self.resnet_dt,
                variables=self.fitting_net_variables,
                suffix=suffix,
            ),
            "@variables": {
                "bias_dos": self.bias_dos,
                "fparam_avg": self.fparam_avg,
                "fparam_inv_std": self.fparam_inv_std,
                "aparam_avg": self.aparam_avg,
                "aparam_inv_std": self.aparam_inv_std,
            },
        }
        return data
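Since this TF payload is meant to stay interchangeable with the dpmodel implementation above, a small test-side helper can assert the invariants visible in the dict literal; this is a hypothetical check, with the schema taken only from serialize() as shown:

def check_dos_payload(data: dict) -> None:
    # Invariants visible in DOSFitting.serialize(); raises on drift.
    assert data["@class"] == "Fitting"
    assert data["type"] == "dos" and data["var_name"] == "dos"
    assert data["@version"] == 1
    assert "nets" in data and "bias_dos" in data["@variables"]
    # Input width of each per-type net: descriptor + frame + atomic params.
    expected_in = data["dim_descrpt"] + data["numb_fparam"] + data["numb_aparam"]
    assert expected_in > 0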
