-
Notifications
You must be signed in to change notification settings - Fork 520
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
rearrange the def of atomic model. provide base class for fitting. re…
…moved the task base.
- Loading branch information
Han Wang
committed
Feb 1, 2024
1 parent
a944441
commit 9701ad9
Showing
23 changed files
with
161 additions
and
101 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,8 +1,12 @@ | ||
# SPDX-License-Identifier: LGPL-3.0-or-later | ||
from .fitting import ( | ||
from .invar_fitting import ( | ||
InvarFitting, | ||
) | ||
from .make_base_fitting import ( | ||
make_base_fitting, | ||
) | ||
|
||
__all__ = [ | ||
"InvarFitting", | ||
"make_base_fitting", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
# SPDX-License-Identifier: LGPL-3.0-or-later | ||
import numpy as np | ||
|
||
from .make_base_fitting import ( | ||
make_base_fitting, | ||
) | ||
|
||
BaseFitting = make_base_fitting(np.ndarray, "call") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
# SPDX-License-Identifier: LGPL-3.0-or-later | ||
from abc import ( | ||
ABC, | ||
abstractclassmethod, | ||
abstractmethod, | ||
) | ||
from typing import ( | ||
Dict, | ||
Optional, | ||
) | ||
|
||
from deepmd.dpmodel.output_def import ( | ||
FittingOutputDef, | ||
) | ||
|
||
|
||
def make_base_fitting( | ||
T_Tensor, | ||
FWD_Method: str = "call", | ||
): | ||
"""Make the base class for the fitting.""" | ||
|
||
class BF(ABC): | ||
"""Base fitting provides the interfaces of fitting net.""" | ||
|
||
@abstractmethod | ||
def output_def(self) -> FittingOutputDef: | ||
pass | ||
|
||
@abstractmethod | ||
def fwd( | ||
self, | ||
descriptor: T_Tensor, | ||
atype: T_Tensor, | ||
gr: Optional[T_Tensor] = None, | ||
g2: Optional[T_Tensor] = None, | ||
h2: Optional[T_Tensor] = None, | ||
fparam: Optional[T_Tensor] = None, | ||
aparam: Optional[T_Tensor] = None, | ||
) -> Dict[str, T_Tensor]: | ||
pass | ||
|
||
@abstractmethod | ||
def serialize(self) -> dict: | ||
pass | ||
|
||
@abstractclassmethod | ||
def deserialize(cls): | ||
pass | ||
|
||
setattr(BF, FWD_Method, BF.fwd) | ||
delattr(BF, "fwd") | ||
|
||
return BF |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
# SPDX-License-Identifier: LGPL-3.0-or-later | ||
|
||
import torch | ||
|
||
from deepmd.dpmodel.model import ( | ||
make_base_atomic_model, | ||
) | ||
|
||
BaseAtomicModel = make_base_atomic_model(torch.Tensor) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
# SPDX-License-Identifier: LGPL-3.0-or-later | ||
import torch | ||
|
||
from deepmd.dpmodel.fitting import ( | ||
make_base_fitting, | ||
) | ||
|
||
BaseFitting = make_base_fitting(torch.Tensor, "forward") |
Oops, something went wrong.