Skip to content

Commit

Permalink
rearrange the def of atomic model. provide base class for fitting. re…
Browse files Browse the repository at this point in the history
…moved the task base.
  • Loading branch information
Han Wang committed Feb 1, 2024
1 parent a944441 commit 9701ad9
Show file tree
Hide file tree
Showing 23 changed files with 161 additions and 101 deletions.
6 changes: 5 additions & 1 deletion deepmd/dpmodel/fitting/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from .fitting import (
from .invar_fitting import (
InvarFitting,
)
from .make_base_fitting import (
make_base_fitting,
)

__all__ = [
"InvarFitting",
"make_base_fitting",
]
8 changes: 8 additions & 0 deletions deepmd/dpmodel/fitting/base_fitting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import numpy as np

from .make_base_fitting import (
make_base_fitting,
)

BaseFitting = make_base_fitting(np.ndarray, "call")
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import copy
from typing import (
Any,
Dict,
List,
Optional,
)
Expand All @@ -22,10 +23,14 @@
NetworkCollection,
)

from .base_fitting import (
BaseFitting,
)


@fitting_check_output
class InvarFitting(NativeOP):
r"""Fitting the energy (or a porperty of `dim_out`) of the system. The force and the virial can also be trained.
class InvarFitting(NativeOP, BaseFitting):
r"""Fitting the energy (or a rotationally invariant porperty of `dim_out`) of the system. The force and the virial can also be trained.
Lets take the energy fitting task as an example.
The potential energy :math:`E` is a fitting network function of the descriptor :math:`\mathcal{D}`:
Expand Down Expand Up @@ -279,7 +284,7 @@ def call(
h2: Optional[np.array] = None,
fparam: Optional[np.array] = None,
aparam: Optional[np.array] = None,
):
) -> Dict[str, np.array]:
"""Calculate the fitting.
Parameters
Expand Down
54 changes: 54 additions & 0 deletions deepmd/dpmodel/fitting/make_base_fitting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from abc import (
ABC,
abstractclassmethod,
abstractmethod,
)
from typing import (
Dict,
Optional,
)

from deepmd.dpmodel.output_def import (
FittingOutputDef,
)


def make_base_fitting(
T_Tensor,
FWD_Method: str = "call",
):
"""Make the base class for the fitting."""

class BF(ABC):
"""Base fitting provides the interfaces of fitting net."""

@abstractmethod
def output_def(self) -> FittingOutputDef:
pass

Check warning on line 28 in deepmd/dpmodel/fitting/make_base_fitting.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/fitting/make_base_fitting.py#L28

Added line #L28 was not covered by tests

@abstractmethod
def fwd(
self,
descriptor: T_Tensor,
atype: T_Tensor,
gr: Optional[T_Tensor] = None,
g2: Optional[T_Tensor] = None,
h2: Optional[T_Tensor] = None,
fparam: Optional[T_Tensor] = None,
aparam: Optional[T_Tensor] = None,
) -> Dict[str, T_Tensor]:
pass

Check warning on line 41 in deepmd/dpmodel/fitting/make_base_fitting.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/fitting/make_base_fitting.py#L41

Added line #L41 was not covered by tests

@abstractmethod
def serialize(self) -> dict:
pass

Check warning on line 45 in deepmd/dpmodel/fitting/make_base_fitting.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/fitting/make_base_fitting.py#L45

Added line #L45 was not covered by tests

@abstractclassmethod
def deserialize(cls):
pass

Check warning on line 49 in deepmd/dpmodel/fitting/make_base_fitting.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/fitting/make_base_fitting.py#L49

Added line #L49 was not covered by tests

setattr(BF, FWD_Method, BF.fwd)
delattr(BF, "fwd")

return BF
2 changes: 1 addition & 1 deletion deepmd/dpmodel/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from .dp_model import (
DPModel,
)
from .make_atomic_model import (
from .make_base_atomic_model import (
make_base_atomic_model,
)

Expand Down
2 changes: 1 addition & 1 deletion deepmd/dpmodel/model/base_atomic_model.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import numpy as np

from .make_atomic_model import (
from .make_base_atomic_model import (
make_base_atomic_model,
)

Expand Down
2 changes: 1 addition & 1 deletion deepmd/dpmodel/model/dp_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def __init__(
self.descriptor = descriptor
self.fitting = fitting

def get_fitting_output_def(self) -> FittingOutputDef:
def fitting_output_def(self) -> FittingOutputDef:
"""Get the output def of the fitting net."""
return self.fitting.output_def()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class BAM(ABC):
"""Base Atomic Model provides the interfaces of an atomic model."""

@abstractmethod
def get_fitting_output_def(self) -> FittingOutputDef:
def fitting_output_def(self) -> FittingOutputDef:
pass

Check warning on line 24 in deepmd/dpmodel/model/make_base_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/make_base_atomic_model.py#L24

Added line #L24 was not covered by tests

@abstractmethod
Expand Down Expand Up @@ -55,4 +55,29 @@ def serialize(self) -> dict:
def deserialize(cls):
pass

Check warning on line 56 in deepmd/dpmodel/model/make_base_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/make_base_atomic_model.py#L56

Added line #L56 was not covered by tests

def do_grad(
self,
var_name: Optional[str] = None,
) -> bool:
"""Tell if the output variable `var_name` is differentiable.
if var_name is None, returns if any of the variable is differentiable.
"""
odef = self.fitting_output_def()
if var_name is None:
require: List[bool] = []
for vv in odef.keys():
require.append(self.do_grad_(vv))
return any(require)
else:
return self.do_grad_(var_name)

def do_grad_(
self,
var_name: str,
) -> bool:
"""Tell if the output variable `var_name` is differentiable."""
assert var_name is not None
return self.fitting_output_def()[var_name].differentiable

return BAM
8 changes: 4 additions & 4 deletions deepmd/dpmodel/model/make_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,9 @@ def __init__(
**kwargs,
)

def get_model_output_def(self):
def model_output_def(self):
"""Get the output def for the model."""
return ModelOutputDef(self.get_fitting_output_def())
return ModelOutputDef(self.fitting_output_def())

def call(
self,
Expand Down Expand Up @@ -125,7 +125,7 @@ def call(
)
model_predict = communicate_extended_output(
model_predict_lower,
self.get_model_output_def(),
self.model_output_def(),
mapping,
do_atomic_virial=do_atomic_virial,
)
Expand Down Expand Up @@ -182,7 +182,7 @@ def call_lower(
)
model_predict = fit_output_to_model_output(
atomic_ret,
self.get_fitting_output_def(),
self.fitting_output_def(),
extended_coord,
do_atomic_virial=do_atomic_virial,
)
Expand Down
42 changes: 0 additions & 42 deletions deepmd/pt/model/model/atomic_model.py

This file was deleted.

9 changes: 9 additions & 0 deletions deepmd/pt/model/model/base_atomic_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# SPDX-License-Identifier: LGPL-3.0-or-later

import torch

from deepmd.dpmodel.model import (
make_base_atomic_model,
)

BaseAtomicModel = make_base_atomic_model(torch.Tensor)
8 changes: 4 additions & 4 deletions deepmd/pt/model/model/dp_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,15 @@
InvarFitting,
)

Check notice

Code scanning / CodeQL

Unused import Note

Import of 'InvarFitting' is not used.

from .atomic_model import (
AtomicModel,
from .base_atomic_model import (
BaseAtomicModel,
)
from .model import (
BaseModel,
)


class DPAtomicModel(BaseModel, AtomicModel):
class DPAtomicModel(BaseModel, BaseAtomicModel):
"""Model give atomic prediction of some physical property.
Parameters
Expand Down Expand Up @@ -84,7 +84,7 @@ def __init__(
sampled=sampled,
)

def get_fitting_output_def(self) -> FittingOutputDef:
def fitting_output_def(self) -> FittingOutputDef:
"""Get the output def of the fitting net."""
return (
self.fitting_net.output_def()
Expand Down
8 changes: 4 additions & 4 deletions deepmd/pt/model/model/make_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,9 @@ def __init__(
**kwargs,
)

def get_model_output_def(self):
def model_output_def(self):
"""Get the output def for the model."""
return ModelOutputDef(self.get_fitting_output_def())
return ModelOutputDef(self.fitting_output_def())

# cannot use the name forward. torch script does not work
def forward_common(
Expand Down Expand Up @@ -123,7 +123,7 @@ def forward_common(
)
model_predict = communicate_extended_output(
model_predict_lower,
self.get_model_output_def(),
self.model_output_def(),
mapping,
do_atomic_virial=do_atomic_virial,
)
Expand Down Expand Up @@ -176,7 +176,7 @@ def forward_common_lower(
)
model_predict = fit_output_to_model_output(
atomic_ret,
self.get_fitting_output_def(),
self.fitting_output_def(),
extended_coord,
do_atomic_virial=do_atomic_virial,
)
Expand Down
8 changes: 4 additions & 4 deletions deepmd/pt/model/model/pair_tab.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,12 @@
PairTab,
)

from .atomic_model import (
AtomicModel,
from .base_atomic_model import (
BaseAtomicModel,
)


class PairTabModel(nn.Module, AtomicModel):
class PairTabModel(nn.Module, BaseAtomicModel):
"""Pairwise tabulation energy model.
This model can be used to tabulate the pairwise energy between atoms for either
Expand Down Expand Up @@ -72,7 +72,7 @@ def __init__(
else:
raise TypeError("sel must be int or list[int]")

def get_fitting_output_def(self) -> FittingOutputDef:
def fitting_output_def(self) -> FittingOutputDef:
return FittingOutputDef(
[
OutputVariableDef(
Expand Down
8 changes: 4 additions & 4 deletions deepmd/pt/model/task/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
from .atten_lcc import (
FittingNetAttenLcc,
)
from .base_fitting import (
BaseFitting,
)
from .denoise import (
DenoiseNet,
)
Expand All @@ -15,9 +18,6 @@
from .fitting import (
Fitting,
)
from .task import (
TaskBaseMethod,
)
from .type_predict import (
TypePredictNet,
)
Expand All @@ -29,6 +29,6 @@
"EnergyFittingNet",
"EnergyFittingNetDirect",
"Fitting",
"TaskBaseMethod",
"BaseFitting",
"TypePredictNet",
]
6 changes: 3 additions & 3 deletions deepmd/pt/model/task/atten_lcc.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,15 @@
EnergyHead,
NodeTaskHead,
)
from deepmd.pt.model.task.task import (
TaskBaseMethod,
from deepmd.pt.model.task.fitting import (
Fitting,
)
from deepmd.pt.utils import (
env,
)


class FittingNetAttenLcc(TaskBaseMethod):
class FittingNetAttenLcc(Fitting):
def __init__(
self, embedding_width, bias_atom_e, pair_embed_dim, attention_heads, **kwargs
):
Expand Down
8 changes: 8 additions & 0 deletions deepmd/pt/model/task/base_fitting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import torch

from deepmd.dpmodel.fitting import (
make_base_fitting,
)

BaseFitting = make_base_fitting(torch.Tensor, "forward")
Loading

0 comments on commit 9701ad9

Please sign in to comment.