Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: breaking: backend indepdent definition for dp model #3208

Merged
merged 32 commits into from
Feb 2, 2024
Merged
Show file tree
Hide file tree
Changes from 24 commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
ba62e99
add chatgpt translated code
Jan 30, 2024
d51f044
fix bugs of nlist and add uts
Jan 31, 2024
4ed1a5e
fix wrapper class name
Jan 31, 2024
25045a1
support int types in type dict
Jan 31, 2024
1f5378b
fix descriptor interfaces (abs class should be introduced later)
Jan 31, 2024
3a029e0
add missing file
Jan 31, 2024
2e4ef50
add dp model def for atomic model
Jan 31, 2024
8c82620
refactor torch atomic model. implement serialize and deserialize. add…
Jan 31, 2024
27f14cd
dp model format for dp models.
Jan 31, 2024
fc3cc8e
dp model format for dp models. add missing files...
Jan 31, 2024
515a724
torch support for the dp model format
Jan 31, 2024
9ca6de3
add ut for the open boundary condition. fix bugs
Jan 31, 2024
7eb1637
support fparam and aparam for atomic model and model. add doc str for…
Jan 31, 2024
40042fa
clean up unused code
Jan 31, 2024
d498a23
fix bugs
Jan 31, 2024
3c84e07
add doc strings
Jan 31, 2024
0c1913b
rm unused code
Jan 31, 2024
96bbb45
fix type hint
Jan 31, 2024
c0a5a56
add format nlist to model
Feb 1, 2024
553ef53
Merge remote-tracking branch 'upstream/devel' into mdfmt-model-1
Feb 1, 2024
f54c724
add place holders to PairTab. fix test: asserting equal between floats
Feb 1, 2024
600349b
forward lower returns reduced virial
Feb 1, 2024
9dd10c9
fix ut
Feb 1, 2024
2ca139d
rm unused vars
Feb 1, 2024
927007c
do not raise, use pass
Feb 1, 2024
6f2d3ac
model_format -> dpmodel, changed module path
Feb 1, 2024
a944441
fix type of the name
Feb 1, 2024
9701ad9
rearrange the def of atomic model. provide base class for fitting. re…
Feb 1, 2024
f412bf7
add base abstract class for descriptor. also fixes get_ntype->get_ntypes
Feb 1, 2024
87fabfb
add distinguish types in base descriptor. fix docstrings
Feb 1, 2024
171e251
possible to change forward name in make base atomic model. fix debug …
Feb 2, 2024
7be7255
fix typo
Feb 2, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions deepmd/model_format/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,12 @@
PRECISION_DICT,
NativeOP,
)
from .dp_atomic_model import (
DPAtomicModel,
)
from .dp_model import (
DPModel,
)
from .env_mat import (
EnvMat,
)
Expand Down Expand Up @@ -37,6 +43,8 @@
)

__all__ = [
"DPModel",
"DPAtomicModel",
"InvarFitting",
"DescrptSeA",
"EnvMat",
Expand Down
58 changes: 58 additions & 0 deletions deepmd/model_format/atomic_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from abc import (
ABC,
abstractclassmethod,
abstractmethod,
)
from typing import (
Dict,
List,
Optional,
)

from .output_def import (
FittingOutputDef,
)


def make_base_atomic_model(T_Tensor):
class BAM(ABC):
"""Base Atomic Model provides the interfaces of an atomic model."""

@abstractmethod
def get_fitting_output_def(self) -> FittingOutputDef:
raise NotImplementedError

Check warning on line 24 in deepmd/model_format/atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/model_format/atomic_model.py#L24

Added line #L24 was not covered by tests

@abstractmethod
def get_rcut(self) -> float:
raise NotImplementedError

Check warning on line 28 in deepmd/model_format/atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/model_format/atomic_model.py#L28

Added line #L28 was not covered by tests

@abstractmethod
def get_sel(self) -> List[int]:
raise NotImplementedError

Check warning on line 32 in deepmd/model_format/atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/model_format/atomic_model.py#L32

Added line #L32 was not covered by tests

@abstractmethod
def distinguish_types(self) -> bool:
raise NotImplementedError

Check warning on line 36 in deepmd/model_format/atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/model_format/atomic_model.py#L36

Added line #L36 was not covered by tests

@abstractmethod
def forward_atomic(
self,
extended_coord: T_Tensor,
extended_atype: T_Tensor,
nlist: T_Tensor,
mapping: Optional[T_Tensor] = None,
fparam: Optional[T_Tensor] = None,
aparam: Optional[T_Tensor] = None,
) -> Dict[str, T_Tensor]:
raise NotImplementedError

Check warning on line 48 in deepmd/model_format/atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/model_format/atomic_model.py#L48

Added line #L48 was not covered by tests

@abstractmethod
def serialize(self) -> dict:
raise NotImplementedError

Check warning on line 52 in deepmd/model_format/atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/model_format/atomic_model.py#L52

Added line #L52 was not covered by tests

@abstractclassmethod
def deserialize(cls):
raise NotImplementedError

Check warning on line 56 in deepmd/model_format/atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/model_format/atomic_model.py#L56

Added line #L56 was not covered by tests

return BAM
8 changes: 8 additions & 0 deletions deepmd/model_format/base_atomic_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import numpy as np

from .atomic_model import (
make_base_atomic_model,
)

BaseAtomicModel = make_base_atomic_model(np.ndarray)
2 changes: 2 additions & 0 deletions deepmd/model_format/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
"half": np.float16,
"single": np.float32,
"double": np.float64,
"int32": np.int32,
"int64": np.int64,
}
DEFAULT_PRECISION = "float64"

Expand Down
134 changes: 134 additions & 0 deletions deepmd/model_format/dp_atomic_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import copy
import sys
from typing import (
Dict,
List,
Optional,
)

import numpy as np

from .base_atomic_model import (
BaseAtomicModel,
)
from .fitting import InvarFitting # noqa # TODO: should import all fittings!

Check notice

Code scanning / CodeQL

Unused import Note

Import of 'InvarFitting' is not used.
from .output_def import (
FittingOutputDef,
)
from .se_e2_a import DescrptSeA # noqa # TODO: should import all descriptors!

Check notice

Code scanning / CodeQL

Unused import Note

Import of 'DescrptSeA' is not used.


class DPAtomicModel(BaseAtomicModel):
"""Model give atomic prediction of some physical property.

Parameters
----------
descriptor
Descriptor
fitting_net
Fitting net
type_map
Mapping atom type to the name (str) of the type.
For example `type_map[1]` gives the name of the type 1.

"""

def __init__(
self,
descriptor,
fitting,
type_map: Optional[List[str]] = None,
):
super().__init__()
self.type_map = type_map
self.descriptor = descriptor
self.fitting = fitting

def get_fitting_output_def(self) -> FittingOutputDef:
"""Get the output def of the fitting net."""
return self.fitting.output_def()

def get_rcut(self) -> float:
"""Get the cut-off radius."""
return self.descriptor.get_rcut()

def get_sel(self) -> List[int]:
"""Get the neighbor selection."""
return self.descriptor.get_sel()

def distinguish_types(self) -> bool:
"""If distinguish different types by sorting."""
return self.descriptor.distinguish_types()

def forward_atomic(
self,
extended_coord: np.ndarray,
extended_atype: np.ndarray,
nlist: np.ndarray,
mapping: Optional[np.ndarray] = None,
fparam: Optional[np.ndarray] = None,
aparam: Optional[np.ndarray] = None,
) -> Dict[str, np.ndarray]:
"""Models' atomic predictions.

Parameters
----------
extended_coord
coodinates in extended region
extended_atype
atomic type in extended region
nlist
neighbor list. nf x nloc x nsel
mapping
mapps the extended indices to local indices. nf x nall
fparam
frame parameter. nf x ndf
aparam
atomic parameter. nf x nloc x nda

Returns
-------
result_dict
the result dict, defined by the `FittingOutputDef`.

"""
nframes, nloc, nnei = nlist.shape
atype = extended_atype[:, :nloc]
descriptor, rot_mat, g2, h2, sw = self.descriptor(
extended_coord,
extended_atype,
nlist,
mapping=mapping,
)
ret = self.fitting(
descriptor,
atype,
gr=rot_mat,
g2=g2,
h2=h2,
fparam=fparam,
aparam=aparam,
)
return ret

def serialize(self) -> dict:
return {
"type_map": self.type_map,
"descriptor": self.descriptor.serialize(),
"fitting": self.fitting.serialize(),
"descriptor_name": self.descriptor.__class__.__name__,
"fitting_name": self.fitting.__class__.__name__,
}

@classmethod
def deserialize(cls, data) -> "DPAtomicModel":
data = copy.deepcopy(data)
descriptor_obj = getattr(
sys.modules[__name__], data["descriptor_name"]
).deserialize(data["descriptor"])
fitting_obj = getattr(sys.modules[__name__], data["fitting_name"]).deserialize(
data["fitting"]
)
obj = cls(descriptor_obj, fitting_obj, type_map=data["type_map"])
return obj
9 changes: 9 additions & 0 deletions deepmd/model_format/dp_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from .dp_atomic_model import (
DPAtomicModel,
)
from .make_model import (
make_model,
)

DPModel = make_model(DPAtomicModel)
3 changes: 2 additions & 1 deletion deepmd/model_format/fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,7 @@ def call(
"which is not consistent with {self.numb_fparam}.",
)
fparam = (fparam - self.fparam_avg) * self.fparam_inv_std
fparam = np.tile(fparam.reshape([nf, 1, -1]), [1, nloc, 1])
fparam = np.tile(fparam.reshape([nf, 1, self.numb_fparam]), [1, nloc, 1])
xx = np.concatenate(
[xx, fparam],
axis=-1,
Expand All @@ -333,6 +333,7 @@ def call(
"get an input aparam of dim {aparam.shape[-1]}, ",
"which is not consistent with {self.numb_aparam}.",
)
aparam = aparam.reshape([nf, nloc, self.numb_aparam])
aparam = (aparam - self.aparam_avg) * self.aparam_inv_std
xx = np.concatenate(
[xx, aparam],
Expand Down
Loading
Loading