
Commit

refactor torch atomic model. implement serialize and deserialize. add UT for consistency
Han Wang committed Jan 31, 2024
1 parent 2e4ef50 commit 8c82620
Showing 7 changed files with 190 additions and 79 deletions.
4 changes: 2 additions & 2 deletions deepmd/model_format/dp_atomic_model.py
@@ -12,11 +12,11 @@
 from .base_atomic_model import (
     BaseAtomicModel,
 )
-from .fitting import InvarFitting  # noqa # should import all fittings!
+from .fitting import InvarFitting  # noqa # TODO: should import all fittings!
 from .output_def import (
     FittingOutputDef,
 )
-from .se_e2_a import DescrptSeA  # noqa # should import all descriptors!
+from .se_e2_a import DescrptSeA  # noqa # TODO: should import all descriptors!
 
 
 class DPAtomicModel(BaseAtomicModel):
31 changes: 29 additions & 2 deletions deepmd/pt/model/model/__init__.py
@@ -1,4 +1,13 @@
 # SPDX-License-Identifier: LGPL-3.0-or-later
+import copy
+
+from deepmd.pt.model.descriptor.descriptor import (
+    Descriptor,
+)
+from deepmd.pt.model.task import (
+    Fitting,
+)
+
 from .ener import (
     EnergyModel,
 )
@@ -8,9 +17,27 @@
 
 
 def get_model(model_params, sampled=None):
+    model_params = copy.deepcopy(model_params)
+    ntypes = len(model_params["type_map"])
+    # descriptor
+    model_params["descriptor"]["ntypes"] = ntypes
+    descriptor = Descriptor(**model_params["descriptor"])
+    # fitting
+    fitting_net = model_params.get("fitting_net", None)
+    fitting_net["type"] = fitting_net.get("type", "ener")
+    fitting_net["ntypes"] = descriptor.get_ntype()
+    fitting_net["distinguish_types"] = descriptor.distinguish_types()
+    fitting_net["embedding_width"] = descriptor.get_dim_out()
+    grad_force = "direct" not in fitting_net["type"]
+    if not grad_force:
+        fitting_net["out_dim"] = descriptor.get_dim_emb()
+    if "ener" in fitting_net["type"]:
+        fitting_net["return_energy"] = True
+    fitting = Fitting(**fitting_net)
+
     return EnergyModel(
-        descriptor=model_params["descriptor"],
-        fitting_net=model_params.get("fitting_net", None),
+        descriptor,
+        fitting,
         type_map=model_params["type_map"],
         type_embedding=model_params.get("type_embedding", None),
         resuming=model_params.get("resuming", False),
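
For reference, a minimal sketch of calling the refactored get_model (parameter values here are illustrative assumptions, not taken from this commit):

from deepmd.pt.model.model import get_model

# Hypothetical input dict; real keys and values come from the training input
# script. "type_map", "descriptor", and "fitting_net" are the keys get_model
# reads above.
model_params = {
    "type_map": ["O", "H"],
    "descriptor": {
        "type": "se_e2_a",
        "rcut": 6.0,
        "rcut_smth": 0.5,
        "sel": [46, 92],
    },
    "fitting_net": {"neuron": [240, 240, 240]},
}
# get_model now instantiates Descriptor and Fitting itself and passes the
# objects, rather than raw dicts, to EnergyModel.
model = get_model(model_params)
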
38 changes: 4 additions & 34 deletions deepmd/pt/model/model/atomic_model.py
@@ -1,49 +1,19 @@
 # SPDX-License-Identifier: LGPL-3.0-or-later
-from abc import (
-    ABC,
-    abstractmethod,
-)
 from typing import (
     Dict,
     List,
     Optional,
 )
 
 import torch
 
-from deepmd.model_format import (
-    FittingOutputDef,
+from deepmd.model_format.atomic_model import (
+    make_base_atomic_model,
 )
 
+BaseAtomicModel = make_base_atomic_model(torch.Tensor)
 
-class AtomicModel(ABC):
-    @abstractmethod
-    def get_fitting_output_def(self) -> FittingOutputDef:
-        raise NotImplementedError
-
-    @abstractmethod
-    def get_rcut(self) -> float:
-        raise NotImplementedError
-
-    @abstractmethod
-    def get_sel(self) -> List[int]:
-        raise NotImplementedError
-
-    @abstractmethod
-    def distinguish_types(self) -> bool:
-        raise NotImplementedError
-
-    @abstractmethod
-    def forward_atomic(
-        self,
-        extended_coord,
-        extended_atype,
-        nlist,
-        mapping: Optional[torch.Tensor] = None,
-        do_atomic_virial: bool = False,
-    ) -> Dict[str, torch.Tensor]:
-        raise NotImplementedError
 
+class AtomicModel(BaseAtomicModel):
     def do_grad(
         self,
         var_name: Optional[str] = None,
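
The abstract interface deleted above now comes from deepmd.model_format: make_base_atomic_model is a class factory that binds the backend tensor type, so the numpy and torch backends can share one interface definition. A minimal sketch of the pattern (a simplified assumption, not the actual deepmd.model_format implementation):

from abc import ABC, abstractmethod
from typing import List


def make_base_atomic_model(t_tensor):
    """Return an abstract atomic-model base class bound to t_tensor."""

    class BAM(ABC):
        @abstractmethod
        def get_rcut(self) -> float:
            """Return the cutoff radius."""

        @abstractmethod
        def get_sel(self) -> List[int]:
            """Return the numbers of selected neighbors per type."""

        # ... distinguish_types, forward_atomic, etc., with t_tensor used in
        # the signatures, as in BaseAtomicModel = make_base_atomic_model(torch.Tensor).

    return BAM
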
76 changes: 40 additions & 36 deletions deepmd/pt/model/model/dp_atomic_model.py
@@ -1,4 +1,6 @@
 # SPDX-License-Identifier: LGPL-3.0-or-later
+import copy
+import sys
 from typing import (
     Dict,
     List,
@@ -10,11 +12,11 @@
 from deepmd.model_format import (
     FittingOutputDef,
 )
-from deepmd.pt.model.descriptor.descriptor import (
-    Descriptor,
+from deepmd.pt.model.descriptor.se_a import (  # noqa # TODO: should import all descriptors!!!
+    DescrptSeA,
 )
-from deepmd.pt.model.task import (
-    Fitting,
+from deepmd.pt.model.task.ener import (  # noqa # TODO: should import all fittings!
+    InvarFitting,
 )
 
 from .atomic_model import (
@@ -49,10 +51,11 @@ class DPAtomicModel(BaseModel, AtomicModel):
         Sampled frames to compute the statistics.
     """
 
+    # I am enough with the shit interface!
     def __init__(
         self,
-        descriptor: dict,
-        fitting_net: dict,
+        descriptor,
+        fitting,
         type_map: Optional[List[str]],
         type_embedding: Optional[dict] = None,
         resuming: bool = False,
@@ -62,26 +65,15 @@ def __init__(
         **kwargs,
     ):
         super().__init__()
-        # Descriptor + Type Embedding Net (Optional)
         ntypes = len(type_map)
         self.type_map = type_map
         self.ntypes = ntypes
-        descriptor["ntypes"] = ntypes
-        self.combination = descriptor.get("combination", False)
-        if self.combination:
-            self.prefactor = descriptor.get("prefactor", [0.5, 0.5])
-        self.descriptor_type = descriptor["type"]
-
-        self.type_split = True
-        if self.descriptor_type not in ["se_e2_a"]:
-            self.type_split = False
-
-        self.descriptor = Descriptor(**descriptor)
+        self.descriptor = descriptor
         self.rcut = self.descriptor.get_rcut()
         self.sel = self.descriptor.get_sel()
-        self.split_nlist = False
-
+        self.fitting_net = fitting
         # Statistics
+        fitting_net = None  # TODO: hack!!! not sure if it is correct.
         self.compute_or_load_stat(
             fitting_net,
             ntypes,
@@ -92,21 +84,6 @@
             sampled=sampled,
         )
 
-        fitting_net["type"] = fitting_net.get("type", "ener")
-        fitting_net["ntypes"] = self.descriptor.get_ntype()
-        if self.descriptor_type in ["se_e2_a"]:
-            fitting_net["distinguish_types"] = True
-        else:
-            fitting_net["distinguish_types"] = False
-        fitting_net["embedding_width"] = self.descriptor.dim_out
-
-        self.grad_force = "direct" not in fitting_net["type"]
-        if not self.grad_force:
-            fitting_net["out_dim"] = self.descriptor.dim_emb
-        if "ener" in fitting_net["type"]:
-            fitting_net["return_energy"] = True
-        self.fitting_net = Fitting(**fitting_net)
-
     def get_fitting_output_def(self) -> FittingOutputDef:
         """Get the output def of the fitting net."""
         return (
@@ -125,7 +102,34 @@ def get_sel(self) -> List[int]:
 
     def distinguish_types(self) -> bool:
         """If distinguish different types by sorting."""
-        return self.type_split
+        return self.descriptor.distinguish_types()
+
+    def serialize(self) -> dict:
+        return {
+            "type_map": self.type_map,
+            "descriptor": self.descriptor.serialize(),
+            "fitting": self.fitting_net.serialize(),
+            "descriptor_name": self.descriptor.__class__.__name__,
+            "fitting_name": self.fitting_net.__class__.__name__,
+        }
+
+    @classmethod
+    def deserialize(cls, data) -> "DPAtomicModel":
+        data = copy.deepcopy(data)
+        descriptor_obj = getattr(
+            sys.modules[__name__], data["descriptor_name"]
+        ).deserialize(data["descriptor"])
+        fitting_obj = getattr(sys.modules[__name__], data["fitting_name"]).deserialize(
+            data["fitting"]
+        )
+        # TODO: dirty hack to provide type_map and avoid data stat!!!
+        obj = cls(
+            descriptor_obj,
+            fitting_obj,
+            type_map=data["type_map"],
+            resuming=True,
+        )
+        return obj
 
     def forward_atomic(
         self,
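
The new serialize/deserialize pair gives DPAtomicModel a plain-dict round-trip. A minimal sketch of the intended use, mirroring the new unit test below (argument values are illustrative):

from deepmd.pt.model.descriptor.se_a import DescrptSeA
from deepmd.pt.model.model.dp_atomic_model import DPAtomicModel
from deepmd.pt.model.task.ener import InvarFitting

ds = DescrptSeA(6.0, 0.5, [46, 92])  # rcut, rcut_smth, sel (illustrative)
ft = InvarFitting("energy", 2, ds.get_dim_out(), 1, distinguish_types=ds.distinguish_types())
# resuming=True skips the data-statistics step (the "dirty hack" noted above).
md0 = DPAtomicModel(ds, ft, type_map=["O", "H"], resuming=True)

data = md0.serialize()                 # plain dict, including the class names
md1 = DPAtomicModel.deserialize(data)  # rebuilds descriptor/fitting via sys.modules
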
4 changes: 0 additions & 4 deletions deepmd/pt/model/model/model.py
@@ -18,10 +18,6 @@ def __init__(self):
"""Construct a basic model for different tasks."""
super().__init__()

def forward(self, *args, **kwargs):
"""Model output."""
raise NotImplementedError

def compute_or_load_stat(
self,
fitting_param,
4 changes: 3 additions & 1 deletion deepmd/pt/model/task/ener.py
@@ -397,7 +397,7 @@ def __init__(
         ntypes,
         embedding_width,
         neuron,
-        bias_atom_e,
+        bias_atom_e=None,
         out_dim=1,
         resnet_dt=True,
         use_tebd=True,
@@ -418,6 +418,8 @@
         self.dim_descrpt = embedding_width
         self.use_tebd = use_tebd
         self.out_dim = out_dim
+        if bias_atom_e is None:
+            bias_atom_e = np.zeros([self.ntypes])
         if not use_tebd:
             assert self.ntypes == len(bias_atom_e), "Element count mismatches!"
         bias_atom_e = torch.tensor(bias_atom_e)
112 changes: 112 additions & 0 deletions source/tests/pt/test_dp_atomic_model.py
@@ -0,0 +1,112 @@
+# SPDX-License-Identifier: LGPL-3.0-or-later
+import unittest
+
+import numpy as np
+import torch
+
+from deepmd.model_format import DescrptSeA as DPDescrptSeA
+from deepmd.model_format import DPAtomicModel as DPDPAtomicModel
+from deepmd.model_format import InvarFitting as DPInvarFitting
+from deepmd.pt.model.descriptor.se_a import (
+    DescrptSeA,
+)
+from deepmd.pt.model.model.dp_atomic_model import (
+    DPAtomicModel,
+)
+from deepmd.pt.model.task.ener import (
+    InvarFitting,
+)
+from deepmd.pt.utils import (
+    env,
+)
+from deepmd.pt.utils.utils import (
+    to_numpy_array,
+    to_torch_tensor,
+)
+
+from .test_env_mat import (
+    TestCaseSingleFrameWithNlist,
+)
+
+dtype = env.GLOBAL_PT_FLOAT_PRECISION
+
+
+class TestInvarFitting(unittest.TestCase, TestCaseSingleFrameWithNlist):
+    def setUp(self):
+        TestCaseSingleFrameWithNlist.setUp(self)
+
+    def test_self_consistency(self):
+        nf, nloc, nnei = self.nlist.shape
+        ds = DescrptSeA(
+            self.rcut,
+            self.rcut_smth,
+            self.sel,
+        ).to(env.DEVICE)
+        ft = InvarFitting(
+            "energy",
+            self.nt,
+            ds.get_dim_out(),
+            1,
+            distinguish_types=ds.distinguish_types(),
+        ).to(env.DEVICE)
+        type_map = ["foo", "bar"]
+        # TODO: dirty hack to avoid data stat!!!
+        md0 = DPAtomicModel(ds, ft, type_map=type_map, resuming=True).to(env.DEVICE)
+        md1 = DPAtomicModel.deserialize(md0.serialize()).to(env.DEVICE)
+        args = [
+            to_torch_tensor(ii) for ii in [self.coord_ext, self.atype_ext, self.nlist]
+        ]
+        ret0 = md0.forward_atomic(*args)
+        ret1 = md1.forward_atomic(*args)
+        np.testing.assert_allclose(
+            to_numpy_array(ret0["energy"]),
+            to_numpy_array(ret1["energy"]),
+        )
+
+    def test_dp_consistency(self):
+        rng = np.random.default_rng()
+        nf, nloc, nnei = self.nlist.shape
+        ds = DPDescrptSeA(
+            self.rcut,
+            self.rcut_smth,
+            self.sel,
+        )
+        ft = DPInvarFitting(
+            "energy",
+            self.nt,
+            ds.get_dim_out(),
+            1,
+            distinguish_types=ds.distinguish_types(),
+        )
+        type_map = ["foo", "bar"]
+        md0 = DPDPAtomicModel(ds, ft, type_map=type_map)
+        md1 = DPAtomicModel.deserialize(md0.serialize()).to(env.DEVICE)
+        args0 = [self.coord_ext, self.atype_ext, self.nlist]
+        args1 = [
+            to_torch_tensor(ii) for ii in [self.coord_ext, self.atype_ext, self.nlist]
+        ]
+        ret0 = md0.forward_atomic(*args0)
+        ret1 = md1.forward_atomic(*args1)
+        np.testing.assert_allclose(
+            ret0["energy"],
+            to_numpy_array(ret1["energy"]),
+        )
+
+    def test_jit(self):
+        nf, nloc, nnei = self.nlist.shape
+        ds = DescrptSeA(
+            self.rcut,
+            self.rcut_smth,
+            self.sel,
+        ).to(env.DEVICE)
+        ft = InvarFitting(
+            "energy",
+            self.nt,
+            ds.get_dim_out(),
+            1,
+            distinguish_types=ds.distinguish_types(),
+        ).to(env.DEVICE)
+        type_map = ["foo", "bar"]
+        # TODO: dirty hack to avoid data stat!!!
+        md0 = DPAtomicModel(ds, ft, type_map=type_map, resuming=True).to(env.DEVICE)
+        torch.jit.script(md0)