Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor DeepEval #3213

Merged
merged 25 commits into from
Feb 8, 2024
Merged
Changes from 1 commit
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
e713910
refactor DeepEval
njzjz Feb 1, 2024
b428fc2
fix cycle import
njzjz Feb 1, 2024
eb0b7d2
fix errors
njzjz Feb 1, 2024
02b3a6d
fix PT tests
njzjz Feb 1, 2024
4fa8754
fix FrozenModel
njzjz Feb 1, 2024
6026ab7
add docs
njzjz Feb 1, 2024
28c5c92
Merge branch 'devel' into deepeval-refactor
njzjz Feb 2, 2024
c262ef6
rename .model_format to .dpmodel
njzjz Feb 2, 2024
4fe9511
rename DeepEvalBase to DeepEvalBackend as the meaning of base is unclear
njzjz Feb 2, 2024
e53b476
Merge remote-tracking branch 'origin/devel' into deepeval-refactor
njzjz Feb 2, 2024
5b16abf
ckean OutputVariableDef
njzjz Feb 2, 2024
51b5e7b
make atom_types to be int32
njzjz Feb 3, 2024
3516792
add docs
njzjz Feb 3, 2024
3bc7cb5
pass list[OutputVariableDef] instead of bool
njzjz Feb 4, 2024
acfd367
fix py38 compatibility
njzjz Feb 4, 2024
db6ac0d
Merge branch 'devel' into deepeval-refactor
njzjz Feb 5, 2024
be4e42f
use consistent output name for different backends
njzjz Feb 5, 2024
199147c
remove the commented out code
njzjz Feb 5, 2024
355c671
Merge branch 'devel' into deepeval-refactor
njzjz Feb 6, 2024
8686fa1
use OutputVariableCategory to check odef
njzjz Feb 6, 2024
c0a4251
fix typo
njzjz Feb 6, 2024
d7bf113
improve docs and type hints
njzjz Feb 8, 2024
acc85fb
Merge branch 'devel' into deepeval-refactor
njzjz Feb 8, 2024
295f4d7
change the OutputVariableDef argument from differentiable to r_differ…
njzjz Feb 8, 2024
6b104e3
fix the virial shape since DERV_C changed from (3,3) to (9,)
njzjz Feb 8, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
make atom_types to be int32
Co-authored-by: Han Wang <[email protected]>
Signed-off-by: Jinzhe Zeng <[email protected]>
njzjz and wanghan-iapcm authored Feb 3, 2024

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature.
commit 51b5e7b75a7cdf7d4c4220a01eea16aa1702843e
2 changes: 1 addition & 1 deletion deepmd/infer/deep_eval.py
Original file line number Diff line number Diff line change
@@ -29,7 +29,7 @@
)

if TYPE_CHECKING:
import ase.neighborlist

Check warning on line 32 in deepmd/infer/deep_eval.py

Codecov / codecov/patch

deepmd/infer/deep_eval.py#L32

Added line #L32 was not covered by tests


class DeepEvalBackend(ABC):
@@ -64,22 +64,22 @@
neighbor_list: Optional["ase.neighborlist.NewPrimitiveNeighborList"] = None,
**kwargs: Dict[str, Any],
) -> None:
pass

Check warning on line 67 in deepmd/infer/deep_eval.py

Codecov / codecov/patch

deepmd/infer/deep_eval.py#L67

Added line #L67 was not covered by tests

def __new__(cls, model_file: str, *args, **kwargs):
if cls is DeepEvalBackend:
backend = detect_backend(model_file)
if backend == DPBackend.TensorFlow:
from deepmd.tf.infer.deep_eval import DeepEval as DeepEvalTF

Check notice

Code scanning / CodeQL

Cyclic import Note

Import of module
deepmd.tf.infer.deep_eval
begins an import cycle.

return super().__new__(DeepEvalTF)
elif backend == DPBackend.PyTorch:
from deepmd.pt.infer.deep_eval import DeepEval as DeepEvalPT

Check notice

Code scanning / CodeQL

Cyclic import Note

Import of module
deepmd.pt.infer.deep_eval
begins an import cycle.

return super().__new__(DeepEvalPT)
else:
raise NotImplementedError("Unsupported backend: " + str(backend))
return super().__new__(cls)

Check warning on line 82 in deepmd/infer/deep_eval.py

Codecov / codecov/patch

deepmd/infer/deep_eval.py#L81-L82

Added lines #L81 - L82 were not covered by tests

@abstractmethod
def eval(
@@ -198,7 +198,7 @@
descriptor
Descriptors.
"""
raise NotImplementedError

Check warning on line 201 in deepmd/infer/deep_eval.py

Codecov / codecov/patch

deepmd/infer/deep_eval.py#L201

Added line #L201 was not covered by tests

def eval_typeebd(self) -> np.ndarray:
"""Evaluate output of type embedding network by using this model.
@@ -215,7 +215,7 @@
KeyError
If the model does not enable type embedding.
"""
raise NotImplementedError

Check warning on line 218 in deepmd/infer/deep_eval.py

Codecov / codecov/patch

deepmd/infer/deep_eval.py#L218

Added line #L218 was not covered by tests

def _check_distinguished_types(self, atom_types: np.ndarray) -> bool:
"""Check if atom types of each frame."""
@@ -232,11 +232,11 @@

def get_numb_dos(self) -> int:
"""Get the number of DOS."""
raise NotImplementedError

Check warning on line 235 in deepmd/infer/deep_eval.py

Codecov / codecov/patch

deepmd/infer/deep_eval.py#L235

Added line #L235 was not covered by tests

def get_has_efield(self):
"""Check if the model has efield."""
return False

Check warning on line 239 in deepmd/infer/deep_eval.py

Codecov / codecov/patch

deepmd/infer/deep_eval.py#L239

Added line #L239 was not covered by tests

@abstractmethod
def get_ntypes_spin(self) -> int:
@@ -441,7 +441,7 @@
coords = np.array(coords)
if cells is not None:
cells = np.array(cells)
atom_types = np.array(atom_types)
atom_types = np.array(atom_types, dtype=np.int32)
if fparam is not None:
fparam = np.array(fparam)
if aparam is not None:
Loading