Skip to content

Commit

Permalink
Merge branch 'devel' of https://github.com/deepmodeling/deepmd-kit in…
Browse files Browse the repository at this point in the history
…to devel
  • Loading branch information
CaRoLZhangxy committed Feb 26, 2024
2 parents 93dacb6 + 261c802 commit eda885a
Show file tree
Hide file tree
Showing 72 changed files with 1,948 additions and 190 deletions.
26 changes: 26 additions & 0 deletions deepmd/backend/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,8 @@ class Feature(Flag):
"""Support Deep Eval backend."""
NEIGHBOR_STAT = auto()
"""Support neighbor statistics."""
IO = auto()
"""Support IO hook."""

name: ClassVar[str] = "Unknown"
"""The formal name of the backend.
Expand Down Expand Up @@ -199,3 +201,27 @@ def neighbor_stat(self) -> Type["NeighborStat"]:
The neighbor statistics of the backend.
"""
pass

@property
@abstractmethod
def serialize_hook(self) -> Callable[[str], dict]:
"""The serialize hook to convert the model file to a dictionary.
Returns
-------
Callable[[str], dict]
The serialize hook of the backend.
"""
pass

@property
@abstractmethod
def deserialize_hook(self) -> Callable[[str, dict], None]:
"""The deserialize hook to convert the dictionary to a model file.
Returns
-------
Callable[[str, dict], None]
The deserialize hook of the backend.
"""
pass
40 changes: 38 additions & 2 deletions deepmd/backend/dpmodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,9 @@ class DPModelBackend(Backend):

name = "DPModel"
"""The formal name of the backend."""
features: ClassVar[Backend.Feature] = Backend.Feature.NEIGHBOR_STAT
features: ClassVar[Backend.Feature] = (
Backend.Feature.DEEP_EVAL | Backend.Feature.NEIGHBOR_STAT | Backend.Feature.IO
)
"""The features of the backend."""
suffixes: ClassVar[List[str]] = [".dp"]
"""The suffixes of the backend."""
Expand Down Expand Up @@ -68,7 +70,11 @@ def deep_eval(self) -> Type["DeepEvalBackend"]:
type[DeepEvalBackend]
The Deep Eval backend of the backend.
"""
raise NotImplementedError(f"Unsupported backend: {self.name}")
from deepmd.dpmodel.infer.deep_eval import (
DeepEval,
)

return DeepEval

@property
def neighbor_stat(self) -> Type["NeighborStat"]:
Expand All @@ -84,3 +90,33 @@ def neighbor_stat(self) -> Type["NeighborStat"]:
)

return NeighborStat

@property
def serialize_hook(self) -> Callable[[str], dict]:
"""The serialize hook to convert the model file to a dictionary.
Returns
-------
Callable[[str], dict]
The serialize hook of the backend.
"""
from deepmd.dpmodel.utils.network import (
load_dp_model,
)

return load_dp_model

@property
def deserialize_hook(self) -> Callable[[str, dict], None]:
"""The deserialize hook to convert the dictionary to a model file.
Returns
-------
Callable[[str, dict], None]
The deserialize hook of the backend.
"""
from deepmd.dpmodel.utils.network import (
save_dp_model,
)

return save_dp_model
31 changes: 31 additions & 0 deletions deepmd/backend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ class TensorFlowBackend(Backend):
Backend.Feature.ENTRY_POINT
| Backend.Feature.DEEP_EVAL
| Backend.Feature.NEIGHBOR_STAT
| Backend.Feature.IO
)
"""The features of the backend."""
suffixes: ClassVar[List[str]] = [".pth", ".pt"]
Expand Down Expand Up @@ -93,3 +94,33 @@ def neighbor_stat(self) -> Type["NeighborStat"]:
)

return NeighborStat

@property
def serialize_hook(self) -> Callable[[str], dict]:
"""The serialize hook to convert the model file to a dictionary.
Returns
-------
Callable[[str], dict]
The serialize hook of the backend.
"""
from deepmd.pt.utils.serialization import (
serialize_from_file,
)

return serialize_from_file

@property
def deserialize_hook(self) -> Callable[[str, dict], None]:
"""The deserialize hook to convert the dictionary to a model file.
Returns
-------
Callable[[str, dict], None]
The deserialize hook of the backend.
"""
from deepmd.pt.utils.serialization import (
deserialize_to_file,
)

return deserialize_to_file
31 changes: 31 additions & 0 deletions deepmd/backend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ class TensorFlowBackend(Backend):
Backend.Feature.ENTRY_POINT
| Backend.Feature.DEEP_EVAL
| Backend.Feature.NEIGHBOR_STAT
| Backend.Feature.IO
)
"""The features of the backend."""
suffixes: ClassVar[List[str]] = [".pb"]
Expand Down Expand Up @@ -102,3 +103,33 @@ def neighbor_stat(self) -> Type["NeighborStat"]:
)

return NeighborStat

@property
def serialize_hook(self) -> Callable[[str], dict]:
"""The serialize hook to convert the model file to a dictionary.
Returns
-------
Callable[[str], dict]
The serialize hook of the backend.
"""
from deepmd.tf.utils.serialization import (
serialize_from_file,
)

return serialize_from_file

@property
def deserialize_hook(self) -> Callable[[str, dict], None]:
"""The deserialize hook to convert the dictionary to a model file.
Returns
-------
Callable[[str, dict], None]
The deserialize hook of the backend.
"""
from deepmd.tf.utils.serialization import (
deserialize_to_file,
)

return deserialize_to_file
21 changes: 21 additions & 0 deletions deepmd/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,3 +313,24 @@ def get_hash(obj) -> str:
object to hash
"""
return sha1(json.dumps(obj).encode("utf-8")).hexdigest()


def j_get_type(data: dict, class_name: str = "object") -> str:
"""Get the type from the data.
Parameters
----------
data : dict
the data
class_name : str, optional
the name of the class for error message, by default "object"
Returns
-------
str
the type
"""
try:
return data["type"]
except KeyError as e:
raise KeyError(f"the type of the {class_name} should be set by `type`") from e
20 changes: 6 additions & 14 deletions deepmd/dpmodel/atomic_model/dp_atomic_model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import copy
import sys
from typing import (
Dict,
List,
Expand All @@ -9,12 +8,11 @@

import numpy as np

from deepmd.dpmodel.descriptor import ( # noqa # TODO: should import all descriptors!
DescrptSeA,
from deepmd.dpmodel.descriptor.base_descriptor import (
BaseDescriptor,
)
from deepmd.dpmodel.fitting import ( # noqa # TODO: should import all fittings!
EnergyFittingNet,
InvarFitting,
from deepmd.dpmodel.fitting.base_fitting import (
BaseFitting,
)
from deepmd.dpmodel.output_def import (
FittingOutputDef,
Expand Down Expand Up @@ -135,19 +133,13 @@ def serialize(self) -> dict:
"type_map": self.type_map,
"descriptor": self.descriptor.serialize(),
"fitting": self.fitting.serialize(),
"descriptor_name": self.descriptor.__class__.__name__,
"fitting_name": self.fitting.__class__.__name__,
}

@classmethod
def deserialize(cls, data) -> "DPAtomicModel":
data = copy.deepcopy(data)
descriptor_obj = getattr(
sys.modules[__name__], data["descriptor_name"]
).deserialize(data["descriptor"])
fitting_obj = getattr(sys.modules[__name__], data["fitting_name"]).deserialize(
data["fitting"]
)
descriptor_obj = BaseDescriptor.deserialize(data["descriptor"])
fitting_obj = BaseFitting.deserialize(data["fitting"])
obj = cls(descriptor_obj, fitting_obj, type_map=data["type_map"])
return obj

Expand Down
1 change: 1 addition & 0 deletions deepmd/dpmodel/descriptor/base_descriptor.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# SPDX-License-Identifier: LGPL-3.0-or-later

import numpy as np

from .make_base_descriptor import (
Expand Down
66 changes: 61 additions & 5 deletions deepmd/dpmodel/descriptor/make_base_descriptor.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,24 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from abc import (
ABC,
abstractclassmethod,
abstractmethod,
)
from typing import (
Callable,
List,
Optional,
Type,
)

from deepmd.common import (
j_get_type,
)
from deepmd.utils.path import (
DPPath,
)
from deepmd.utils.plugin import (
Plugin,
)


def make_base_descriptor(
Expand All @@ -33,6 +40,42 @@ def make_base_descriptor(
class BD(ABC):
"""Base descriptor provides the interfaces of descriptor."""

__plugins = Plugin()

@staticmethod
def register(key: str) -> Callable:
"""Register a descriptor plugin.
Parameters
----------
key : str
the key of a descriptor
Returns
-------
Descriptor
the registered descriptor
Examples
--------
>>> @Descriptor.register("some_descrpt")
class SomeDescript(Descriptor):
pass
"""
return BD.__plugins.register(key)

def __new__(cls, *args, **kwargs):
if cls is BD:
cls = cls.get_class_by_type(j_get_type(kwargs, cls.__name__))
return super().__new__(cls)

@classmethod
def get_class_by_type(cls, descrpt_type: str) -> Type["BD"]:
if descrpt_type in BD.__plugins.plugins:
return BD.__plugins.plugins[descrpt_type]
else:
raise RuntimeError("Unknown descriptor type: " + descrpt_type)

@abstractmethod
def get_rcut(self) -> float:
"""Returns the cut-off radius."""
Expand Down Expand Up @@ -95,10 +138,23 @@ def serialize(self) -> dict:
"""Serialize the obj to dict."""
pass

@abstractclassmethod
def deserialize(cls):
"""Deserialize from a dict."""
pass
@classmethod
def deserialize(cls, data: dict) -> "BD":
"""Deserialize the model.
Parameters
----------
data : dict
The serialized data
Returns
-------
BD
The deserialized descriptor
"""
if cls is BD:
return BD.get_class_by_type(data["type"]).deserialize(data)
raise NotImplementedError("Not implemented in class %s" % cls.__name__)

setattr(BD, fwd_method_name, BD.fwd)
delattr(BD, "fwd")
Expand Down
5 changes: 5 additions & 0 deletions deepmd/dpmodel/descriptor/se_e2_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
)


@BaseDescriptor.register("se_e2_a")
class DescrptSeA(NativeOP, BaseDescriptor):
r"""DeepPot-SE constructed from all information (both angular and radial) of
atomic configurations. The embedding takes the distance between atoms as input.
Expand Down Expand Up @@ -313,6 +314,8 @@ def call(
def serialize(self) -> dict:
"""Serialize the descriptor to dict."""
return {
"@class": "Descriptor",
"type": "se_e2_a",
"rcut": self.rcut,
"rcut_smth": self.rcut_smth,
"sel": self.sel,
Expand All @@ -339,6 +342,8 @@ def serialize(self) -> dict:
def deserialize(cls, data: dict) -> "DescrptSeA":
"""Deserialize from dict."""
data = copy.deepcopy(data)
data.pop("@class", None)
data.pop("type", None)
variables = data.pop("@variables")
embeddings = data.pop("embeddings")
env_mat = data.pop("env_mat")
Expand Down
Loading

0 comments on commit eda885a

Please sign in to comment.