Skip to content

Commit

Permalink
add BaseModel; store type in serialization (deepmodeling#3335)
Browse files Browse the repository at this point in the history
Signed-off-by: Jinzhe Zeng <[email protected]>
Co-authored-by: Han Wang <[email protected]>
  • Loading branch information
njzjz and wanghan-iapcm authored Feb 27, 2024
1 parent c538d04 commit 854d998
Show file tree
Hide file tree
Showing 17 changed files with 370 additions and 18 deletions.
4 changes: 4 additions & 0 deletions deepmd/dpmodel/atomic_model/dp_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,8 @@ def forward_atomic(

def serialize(self) -> dict:
return {
"@class": "Model",
"type": "standard",
"type_map": self.type_map,
"descriptor": self.descriptor.serialize(),
"fitting": self.fitting.serialize(),
Expand All @@ -138,6 +140,8 @@ def serialize(self) -> dict:
@classmethod
def deserialize(cls, data) -> "DPAtomicModel":
data = copy.deepcopy(data)
data.pop("@class")
data.pop("type")
descriptor_obj = BaseDescriptor.deserialize(data["descriptor"])
fitting_obj = BaseFitting.deserialize(data["fitting"])
obj = cls(descriptor_obj, fitting_obj, type_map=data["type_map"])
Expand Down
11 changes: 11 additions & 0 deletions deepmd/dpmodel/atomic_model/linear_atomic_model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import copy
import sys
from abc import (
abstractmethod,
Expand Down Expand Up @@ -182,12 +183,17 @@ def fitting_output_def(self) -> FittingOutputDef:
@staticmethod
def serialize(models) -> dict:
return {
"@class": "Model",
"type": "linear",
"models": [model.serialize() for model in models],
"model_name": [model.__class__.__name__ for model in models],
}

@staticmethod
def deserialize(data) -> List[BaseAtomicModel]:
data = copy.deepcopy(data)
data.pop("@class")
data.pop("type")
model_names = data["model_name"]
models = [
getattr(sys.modules[__name__], name).deserialize(model)
Expand Down Expand Up @@ -263,6 +269,8 @@ def __init__(

def serialize(self) -> dict:
return {
"@class": "Model",
"type": "zbl",
"models": LinearAtomicModel.serialize([self.dp_model, self.zbl_model]),
"sw_rmin": self.sw_rmin,
"sw_rmax": self.sw_rmax,
Expand All @@ -271,6 +279,9 @@ def serialize(self) -> dict:

@classmethod
def deserialize(cls, data) -> "DPZBLLinearAtomicModel":
data = copy.deepcopy(data)
data.pop("@class")
data.pop("type")
sw_rmin = data["sw_rmin"]
sw_rmax = data["sw_rmax"]
smin_alpha = data["smin_alpha"]
Expand Down
4 changes: 4 additions & 0 deletions deepmd/dpmodel/atomic_model/make_base_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,10 @@ def do_grad_(self, var_name: str, base: str) -> bool:
return self.fitting_output_def()[var_name].c_differentiable
return self.fitting_output_def()[var_name].r_differentiable

def get_model_def_script(self) -> str:
# TODO: implement this method; saved to model
raise NotImplementedError

setattr(BAM, fwd_method_name, BAM.fwd)
delattr(BAM, "fwd")

Expand Down
12 changes: 11 additions & 1 deletion deepmd/dpmodel/atomic_model/pairtab_atomic_model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import copy
from typing import (
Dict,
List,
Expand Down Expand Up @@ -105,10 +106,19 @@ def mixed_types(self) -> bool:
return True

def serialize(self) -> dict:
return {"tab": self.tab.serialize(), "rcut": self.rcut, "sel": self.sel}
return {
"@class": "Model",
"type": "pairtab",
"tab": self.tab.serialize(),
"rcut": self.rcut,
"sel": self.sel,
}

@classmethod
def deserialize(cls, data) -> "PairTabAtomicModel":
data = copy.deepcopy(data)
data.pop("@class")
data.pop("type")
rcut = data["rcut"]
sel = data["sel"]
tab = PairTab.deserialize(data["tab"])
Expand Down
6 changes: 3 additions & 3 deletions deepmd/dpmodel/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@

import numpy as np

from deepmd.dpmodel.model.dp_model import (
DPModel,
from deepmd.dpmodel.model.base_model import (
BaseModel,
)
from deepmd.dpmodel.output_def import (
ModelOutputDef,
Expand Down Expand Up @@ -85,7 +85,7 @@ def __init__(
self.model_path = model_file

model_data = load_dp_model(model_file)
self.dp = DPModel.deserialize(model_data["model"])
self.dp = BaseModel.deserialize(model_data["model"])
self.rcut = self.dp.get_rcut()
self.type_map = self.dp.get_type_map()
if isinstance(auto_batch_size, bool):
Expand Down
158 changes: 158 additions & 0 deletions deepmd/dpmodel/model/base_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import inspect
from abc import (
ABC,
abstractmethod,
)
from typing import (
Any,
List,
Type,
)

from deepmd.utils.plugin import (
make_plugin_registry,
)


def make_base_model() -> Type[object]:
class BaseBaseModel(ABC, make_plugin_registry("model")):
"""Base class for final exported model that will be directly used for inference.
The class defines some abstractmethods that will be directly called by the
inference interface. If the final model class inherits some of those methods
from other classes, `BaseModel` should be inherited as the last class to ensure
the correct method resolution order.
This class is backend-indepedent.
See Also
--------
deepmd.dpmodel.model.base_model.BaseModel
BaseModel class for DPModel backend.
"""

def __new__(cls, *args, **kwargs):
if inspect.isabstract(cls):
cls = cls.get_class_by_type(kwargs.get("type", "standard"))
return super().__new__(cls)

@abstractmethod
def __call__(self, *args: Any, **kwds: Any) -> Any:
"""Inference method.
Parameters
----------
*args : Any
The input data for inference.
**kwds : Any
The input data for inference.
Returns
-------
Any
The output of the inference.
"""
pass

@abstractmethod
def get_type_map(self) -> List[str]:
"""Get the type map."""

@abstractmethod
def get_rcut(self):
"""Get the cut-off radius."""

@abstractmethod
def get_dim_fparam(self):
"""Get the number (dimension) of frame parameters of this atomic model."""

@abstractmethod
def get_dim_aparam(self):
"""Get the number (dimension) of atomic parameters of this atomic model."""

@abstractmethod
def get_sel_type(self) -> List[int]:
"""Get the selected atom types of this model.
Only atoms with selected atom types have atomic contribution
to the result of the model.
If returning an empty list, all atom types are selected.
"""

@abstractmethod
def is_aparam_nall(self) -> bool:
"""Check whether the shape of atomic parameters is (nframes, nall, ndim).
If False, the shape is (nframes, nloc, ndim).
"""

@abstractmethod
def model_output_type(self) -> str:
"""Get the output type for the model."""

@abstractmethod
def serialize(self) -> dict:
"""Serialize the model.
Returns
-------
dict
The serialized data
"""
pass

@classmethod
def deserialize(cls, data: dict) -> "BaseBaseModel":
"""Deserialize the model.
Parameters
----------
data : dict
The serialized data
Returns
-------
BaseModel
The deserialized model
"""
if inspect.isabstract(cls):
return cls.get_class_by_type(data["type"]).deserialize(data)
raise NotImplementedError("Not implemented in class %s" % cls.__name__)

model_def_script: str

@abstractmethod
def get_model_def_script(self) -> str:
"""Get the model definition script."""
pass

@abstractmethod
def get_nnei(self) -> int:
"""Returns the total number of selected neighboring atoms in the cut-off radius."""
# for C++ interface
pass

@abstractmethod
def get_nsel(self) -> int:
"""Returns the total number of selected neighboring atoms in the cut-off radius."""
pass

return BaseBaseModel


class BaseModel(make_base_model()):
"""Base class for final exported model that will be directly used for inference.
The class defines some abstractmethods that will be directly called by the
inference interface. If the final model class inherbits some of those methods
from other classes, `BaseModel` should be inherited as the last class to ensure
the correct method resolution order.
This class is for the DPModel backend.
See Also
--------
deepmd.dpmodel.model.base_model.BaseBaseModel
Backend-independent BaseModel class.
"""
6 changes: 5 additions & 1 deletion deepmd/dpmodel/model/dp_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,16 @@
from deepmd.dpmodel.atomic_model import (
DPAtomicModel,
)
from deepmd.dpmodel.model.base_model import (
BaseModel,
)

from .make_model import (
make_model,
)


# use "class" to resolve "Variable not allowed in type expression"
class DPModel(make_model(DPAtomicModel)):
@BaseModel.register("standard")
class DPModel(make_model(DPAtomicModel), BaseModel):
pass
2 changes: 2 additions & 0 deletions deepmd/pt/model/atomic_model/dp_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,8 @@ def mixed_types(self) -> bool:

def serialize(self) -> dict:
return {
"@class": "Model",
"type": "standard",
"type_map": self.type_map,
"descriptor": self.descriptor.serialize(),
"fitting": self.fitting_net.serialize(),
Expand Down
4 changes: 4 additions & 0 deletions deepmd/pt/model/atomic_model/linear_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,8 @@ def fitting_output_def(self) -> FittingOutputDef:
@staticmethod
def serialize(models) -> dict:
return {
"@class": "Model",
"type": "linear",
"models": [model.serialize() for model in models],
"model_name": [model.__class__.__name__ for model in models],
}
Expand Down Expand Up @@ -299,6 +301,8 @@ def __init__(

def serialize(self) -> dict:
return {
"@class": "Model",
"type": "zbl",
"models": LinearAtomicModel.serialize([self.dp_model, self.zbl_model]),
"sw_rmin": self.sw_rmin,
"sw_rmax": self.sw_rmax,
Expand Down
8 changes: 7 additions & 1 deletion deepmd/pt/model/atomic_model/pairtab_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,13 @@ def mixed_types(self) -> bool:
return True

def serialize(self) -> dict:
return {"tab": self.tab.serialize(), "rcut": self.rcut, "sel": self.sel}
return {
"@class": "Model",
"type": "pairtab",
"tab": self.tab.serialize(),
"rcut": self.rcut,
"sel": self.sel,
}

@classmethod
def deserialize(cls, data) -> "PairTabAtomicModel":
Expand Down
38 changes: 37 additions & 1 deletion deepmd/pt/model/model/dp_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,45 @@
from deepmd.pt.model.atomic_model import (
DPAtomicModel,
)
from deepmd.pt.model.model.model import (
BaseModel,
)
from deepmd.pt.model.task.dipole import (
DipoleFittingNet,
)
from deepmd.pt.model.task.ener import (
EnergyFittingNet,
)
from deepmd.pt.model.task.polarizability import (
PolarFittingNet,
)

from .make_model import (
make_model,
)

DPModel = make_model(DPAtomicModel)

@BaseModel.register("standard")
class DPModel(make_model(DPAtomicModel), BaseModel):
def __new__(cls, descriptor, fitting, *args, **kwargs):
from deepmd.pt.model.model.dipole_model import (
DipoleModel,
)
from deepmd.pt.model.model.ener_model import (
EnergyModel,
)
from deepmd.pt.model.model.polar_model import (
PolarModel,
)

# according to the fitting network to decide the type of the model
if cls is DPModel:
# map fitting to model
if isinstance(fitting, EnergyFittingNet):
cls = EnergyModel
elif isinstance(fitting, DipoleFittingNet):
cls = DipoleModel
elif isinstance(fitting, PolarFittingNet):
cls = PolarModel
# else: unknown fitting type, fall back to DPModel
return super().__new__(cls)
Loading

0 comments on commit 854d998

Please sign in to comment.