Skip to content

Commit

Permalink
chore: improve type anotations in deepmd.infer
Browse files Browse the repository at this point in the history
Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz committed May 17, 2024
1 parent d62a41f commit 81ec3b6
Show file tree
Hide file tree
Showing 10 changed files with 89 additions and 43 deletions.
8 changes: 4 additions & 4 deletions deepmd/dpmodel/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,10 +76,10 @@ def __init__(
self,
model_file: str,
output_def: ModelOutputDef,
*args: List[Any],
*args: Any,
auto_batch_size: Union[bool, int, AutoBatchSize] = True,
neighbor_list: Optional["ase.neighborlist.NewPrimitiveNeighborList"] = None,
**kwargs: Dict[str, Any],
**kwargs: Any,
):
self.output_def = output_def
self.model_path = model_file
Expand Down Expand Up @@ -161,12 +161,12 @@ def get_ntypes_spin(self):
def eval(
self,
coords: np.ndarray,
cells: np.ndarray,
cells: Optional[np.ndarray],
atom_types: np.ndarray,
atomic: bool = False,
fparam: Optional[np.ndarray] = None,
aparam: Optional[np.ndarray] = None,
**kwargs: Dict[str, Any],
**kwargs: Any,
) -> Dict[str, np.ndarray]:
"""Evaluate the energy, force and virial by using this DP.
Expand Down
9 changes: 1 addition & 8 deletions deepmd/entrypoints/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,14 +46,7 @@
)

if TYPE_CHECKING:
from deepmd.tf.infer import (
DeepDipole,
DeepDOS,
DeepPolar,
DeepPot,
DeepWFC,
)
from deepmd.tf.infer.deep_tensor import (
from deepmd.infer.deep_tensor import (

Check warning on line 49 in deepmd/entrypoints/test.py

View check run for this annotation

Codecov / codecov/patch

deepmd/entrypoints/test.py#L49

Added line #L49 was not covered by tests
DeepTensor,
)

Expand Down
3 changes: 1 addition & 2 deletions deepmd/infer/deep_dos.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
Any,
Dict,
List,
Optional,
Tuple,
Expand Down Expand Up @@ -70,7 +69,7 @@ def eval(
fparam: Optional[np.ndarray] = None,
aparam: Optional[np.ndarray] = None,
mixed_type: bool = False,
**kwargs: Dict[str, Any],
**kwargs: Any,
) -> Tuple[np.ndarray, ...]:
"""Evaluate energy, force, and virial. If atomic is True,
also return atomic energy and atomic virial.
Expand Down
23 changes: 12 additions & 11 deletions deepmd/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
List,
Optional,
Tuple,
Type,
Union,
)

Expand Down Expand Up @@ -82,10 +83,10 @@ def __init__(
self,
model_file: str,
output_def: ModelOutputDef,
*args: List[Any],
*args: Any,
auto_batch_size: Union[bool, int, AutoBatchSize] = True,
neighbor_list: Optional["ase.neighborlist.NewPrimitiveNeighborList"] = None,
**kwargs: Dict[str, Any],
**kwargs: Any,
) -> None:
pass

Expand All @@ -99,12 +100,12 @@ def __new__(cls, model_file: str, *args, **kwargs):
def eval(
self,
coords: np.ndarray,
cells: np.ndarray,
cells: Optional[np.ndarray],
atom_types: np.ndarray,
atomic: bool = False,
fparam: Optional[np.ndarray] = None,
aparam: Optional[np.ndarray] = None,
**kwargs: Dict[str, Any],
**kwargs: Any,
) -> Dict[str, np.ndarray]:
"""Evaluate the energy, force and virial by using this DP.
Expand Down Expand Up @@ -166,13 +167,13 @@ def get_dim_aparam(self) -> int:
def eval_descriptor(
self,
coords: np.ndarray,
cells: np.ndarray,
cells: Optional[np.ndarray],
atom_types: np.ndarray,
fparam: Optional[np.ndarray] = None,
aparam: Optional[np.ndarray] = None,
efield: Optional[np.ndarray] = None,
mixed_type: bool = False,
**kwargs: Dict[str, Any],
**kwargs: Any,
) -> np.ndarray:
"""Evaluate descriptors by using this DP.
Expand Down Expand Up @@ -246,11 +247,11 @@ def _check_mixed_types(self, atom_types: np.ndarray) -> bool:
# assume mixed_types if there are virtual types, even when
# the atom types of all frames are the same
return False
return np.all(np.equal(atom_types, atom_types[0]))
return np.all(np.equal(atom_types, atom_types[0])).item()

@property
@abstractmethod
def model_type(self) -> "DeepEval":
def model_type(self) -> Type["DeepEval"]:
"""The the evaluator of the model type."""

@abstractmethod
Expand Down Expand Up @@ -316,10 +317,10 @@ def __new__(cls, model_file: str, *args, **kwargs):
def __init__(
self,
model_file: str,
*args: List[Any],
*args: Any,
auto_batch_size: Union[bool, int, AutoBatchSize] = True,
neighbor_list: Optional["ase.neighborlist.NewPrimitiveNeighborList"] = None,
**kwargs: Dict[str, Any],
**kwargs: Any,
) -> None:
self.deep_eval = DeepEvalBackend(
model_file,
Expand Down Expand Up @@ -387,7 +388,7 @@ def eval_descriptor(
fparam: Optional[np.ndarray] = None,
aparam: Optional[np.ndarray] = None,
mixed_type: bool = False,
**kwargs: Dict[str, Any],
**kwargs: Any,
) -> np.ndarray:
"""Evaluate descriptors by using this DP.
Expand Down
2 changes: 1 addition & 1 deletion deepmd/infer/deep_polar.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def eval(
fparam: Optional[np.ndarray] = None,
aparam: Optional[np.ndarray] = None,
mixed_type: bool = False,
**kwargs: dict,
**kwargs,
) -> np.ndarray:
"""Evaluate the model.
Expand Down
47 changes: 45 additions & 2 deletions deepmd/infer/deep_pot.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
Any,
Dict,
List,
Literal,
Optional,
Tuple,
Union,
overload,
)

import numpy as np
Expand Down Expand Up @@ -89,6 +90,48 @@ def output_def_mag(self) -> ModelOutputDef:
)
)

@overload
def eval(
self,
coords: np.ndarray,
cells: Optional[np.ndarray],
atom_types: Union[List[int], np.ndarray],
atomic: Literal[True],
fparam: Optional[np.ndarray],
aparam: Optional[np.ndarray],
mixed_type: bool,
**kwargs: Any,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
pass

Check warning on line 105 in deepmd/infer/deep_pot.py

View check run for this annotation

Codecov / codecov/patch

deepmd/infer/deep_pot.py#L105

Added line #L105 was not covered by tests

@overload
def eval(
self,
coords: np.ndarray,
cells: Optional[np.ndarray],
atom_types: Union[List[int], np.ndarray],
atomic: Literal[False],
fparam: Optional[np.ndarray],
aparam: Optional[np.ndarray],
mixed_type: bool,
**kwargs: Any,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
pass

Check warning on line 119 in deepmd/infer/deep_pot.py

View check run for this annotation

Codecov / codecov/patch

deepmd/infer/deep_pot.py#L119

Added line #L119 was not covered by tests

@overload
def eval(
self,
coords: np.ndarray,
cells: Optional[np.ndarray],
atom_types: Union[List[int], np.ndarray],
atomic: bool,
fparam: Optional[np.ndarray],
aparam: Optional[np.ndarray],
mixed_type: bool,
**kwargs: Any,
) -> Tuple[np.ndarray, ...]:
pass

Check warning on line 133 in deepmd/infer/deep_pot.py

View check run for this annotation

Codecov / codecov/patch

deepmd/infer/deep_pot.py#L133

Added line #L133 was not covered by tests

def eval(
self,
coords: np.ndarray,
Expand All @@ -98,7 +141,7 @@ def eval(
fparam: Optional[np.ndarray] = None,
aparam: Optional[np.ndarray] = None,
mixed_type: bool = False,
**kwargs: Dict[str, Any],
**kwargs: Any,
) -> Tuple[np.ndarray, ...]:
"""Evaluate energy, force, and virial. If atomic is True,
also return atomic energy and atomic virial.
Expand Down
14 changes: 11 additions & 3 deletions deepmd/infer/model_devi.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def calc_model_devi_f(
fs: np.ndarray,
real_f: Optional[np.ndarray] = None,
relative: Optional[float] = None,
atomic: Literal[False] = False,
atomic: Literal[False] = ...,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: ...


Expand All @@ -37,11 +37,19 @@ def calc_model_devi_f(
fs: np.ndarray,
real_f: Optional[np.ndarray] = None,
relative: Optional[float] = None,
*,
atomic: Literal[True],
atomic: Literal[True] = ...,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: ...


@overload
def calc_model_devi_f(
fs: np.ndarray,
real_f: Optional[np.ndarray] = None,
relative: Optional[float] = None,
atomic: bool = False,
) -> Tuple[np.ndarray, ...]: ...

Check notice

Code scanning / CodeQL

Statement has no effect Note

This statement has no effect.


def calc_model_devi_f(
fs: np.ndarray,
real_f: Optional[np.ndarray] = None,
Expand Down
11 changes: 6 additions & 5 deletions deepmd/pt/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
List,
Optional,
Tuple,
Type,
Union,
)

Expand Down Expand Up @@ -87,11 +88,11 @@ def __init__(
self,
model_file: str,
output_def: ModelOutputDef,
*args: List[Any],
*args: Any,
auto_batch_size: Union[bool, int, AutoBatchSize] = True,
neighbor_list: Optional["ase.neighborlist.NewPrimitiveNeighborList"] = None,
head: Optional[str] = None,
**kwargs: Dict[str, Any],
**kwargs: Any,
):
self.output_def = output_def
self.model_path = model_file
Expand Down Expand Up @@ -165,7 +166,7 @@ def get_dim_aparam(self) -> int:
return self.dp.model["Default"].get_dim_aparam()

@property
def model_type(self) -> "DeepEvalWrapper":
def model_type(self) -> Type["DeepEvalWrapper"]:
"""The the evaluator of the model type."""
model_output_type = self.dp.model["Default"].model_output_type()
if "energy" in model_output_type:
Expand Down Expand Up @@ -211,12 +212,12 @@ def get_has_spin(self):
def eval(
self,
coords: np.ndarray,
cells: np.ndarray,
cells: Optional[np.ndarray],
atom_types: np.ndarray,
atomic: bool = False,
fparam: Optional[np.ndarray] = None,
aparam: Optional[np.ndarray] = None,
**kwargs: Dict[str, Any],
**kwargs: Any,
) -> Dict[str, np.ndarray]:
"""Evaluate the energy, force and virial by using this DP.
Expand Down
11 changes: 6 additions & 5 deletions deepmd/tf/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
List,
Optional,
Tuple,
Type,
Union,
)

Expand Down Expand Up @@ -262,7 +263,7 @@ def _init_attr(self):

@property
@lru_cache(maxsize=None)
def model_type(self) -> "DeepEvalWrapper":
def model_type(self) -> Type["DeepEvalWrapper"]:
"""Get type of model.
:type:str
Expand Down Expand Up @@ -693,13 +694,13 @@ def _get_natoms_and_nframes(
def eval(
self,
coords: np.ndarray,
cells: np.ndarray,
cells: Optional[np.ndarray],
atom_types: np.ndarray,
atomic: bool = False,
fparam: Optional[np.ndarray] = None,
aparam: Optional[np.ndarray] = None,
efield: Optional[np.ndarray] = None,
**kwargs: Dict[str, Any],
**kwargs: Any,
) -> Dict[str, np.ndarray]:
"""Evaluate the energy, force and virial by using this DP.
Expand Down Expand Up @@ -1023,7 +1024,7 @@ def _get_output_shape(self, odef, nframes, natoms):
def eval_descriptor(
self,
coords: np.ndarray,
cells: np.ndarray,
cells: Optional[np.ndarray],
atom_types: np.ndarray,
fparam: Optional[np.ndarray] = None,
aparam: Optional[np.ndarray] = None,
Expand Down Expand Up @@ -1080,7 +1081,7 @@ def eval_descriptor(
def _eval_descriptor_inner(
self,
coords: np.ndarray,
cells: np.ndarray,
cells: Optional[np.ndarray],
atom_types: np.ndarray,
fparam: Optional[np.ndarray] = None,
aparam: Optional[np.ndarray] = None,
Expand Down
4 changes: 2 additions & 2 deletions deepmd/tf/infer/deep_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def get_dim_aparam(self) -> int:
def eval(
self,
coords: np.ndarray,
cells: np.ndarray,
cells: Optional[np.ndarray],
atom_types: List[int],
atomic: bool = True,
fparam: Optional[np.ndarray] = None,
Expand Down Expand Up @@ -276,7 +276,7 @@ def eval(
def eval_full(
self,
coords: np.ndarray,
cells: np.ndarray,
cells: Optional[np.ndarray],
atom_types: List[int],
atomic: bool = False,
fparam: Optional[np.array] = None,
Expand Down

0 comments on commit 81ec3b6

Please sign in to comment.